mirror of https://github.com/VLSIDA/OpenRAM.git
Convert entire OpenRAM to use python3. Works with Python 3.6.
Major changes: Remove mpmath library and use numpy instead. Convert bytes to new bytearrays. Fix class name check for duplicate gds instances. Add explicit integer conversion from floats. Fix importlib reload from importlib library Fix new key/index syntax issues. Fix filter and map conversion to lists. Fix deprecation warnings. Fix Circuits vs Netlist in Magic LVS results. Fix file closing warnings.
This commit is contained in:
parent
58628d7867
commit
f34c4eb7dc
|
|
@ -29,18 +29,18 @@ class design(hierarchy_spice.spice, hierarchy_layout.layout):
|
|||
# because each reference must be a unique name.
|
||||
# These modules ensure unique names or have no changes if they
|
||||
# aren't unique
|
||||
ok_list = ['ms_flop.ms_flop',
|
||||
'dff.dff',
|
||||
'dff_buf.dff_buf',
|
||||
'bitcell.bitcell',
|
||||
'contact.contact',
|
||||
'ptx.ptx',
|
||||
'sram.sram',
|
||||
'hierarchical_predecode2x4.hierarchical_predecode2x4',
|
||||
'hierarchical_predecode3x8.hierarchical_predecode3x8']
|
||||
ok_list = ['ms_flop',
|
||||
'dff',
|
||||
'dff_buf',
|
||||
'bitcell',
|
||||
'contact',
|
||||
'ptx',
|
||||
'sram',
|
||||
'hierarchical_predecode2x4',
|
||||
'hierarchical_predecode3x8']
|
||||
if name not in design.name_map:
|
||||
design.name_map.append(name)
|
||||
elif str(self.__class__) in ok_list:
|
||||
elif self.__class__.__name__ in ok_list:
|
||||
pass
|
||||
else:
|
||||
debug.error("Duplicate layout reference name {0} of class {1}. GDS2 requires names be unique.".format(name,self.__class__),-1)
|
||||
|
|
|
|||
|
|
@ -116,10 +116,11 @@ class spice(verilog.verilog):
|
|||
self.spice = f.readlines()
|
||||
for i in range(len(self.spice)):
|
||||
self.spice[i] = self.spice[i].rstrip(" \n")
|
||||
f.close()
|
||||
|
||||
# find the correct subckt line in the file
|
||||
subckt = re.compile("^.subckt {}".format(self.name), re.IGNORECASE)
|
||||
subckt_line = filter(subckt.search, self.spice)[0]
|
||||
subckt_line = list(filter(subckt.search, self.spice))[0]
|
||||
# parses line into ports and remove subckt
|
||||
self.pins = subckt_line.split(" ")[2:]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class pin_layout:
|
|||
self.rect = [x.snap_to_grid() for x in self.rect]
|
||||
# if it's a layer number look up the layer name. this assumes a unique layer number.
|
||||
if type(layer_name_num)==int:
|
||||
self.layer = layer.keys()[layer.values().index(layer_name_num)]
|
||||
self.layer = list(layer.keys())[list(layer.values()).index(layer_name_num)]
|
||||
else:
|
||||
self.layer=layer_name_num
|
||||
self.layer_num = layer[self.layer]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
import debug
|
||||
from globals import OPTS,find_exe,get_tool
|
||||
import lib
|
||||
import delay
|
||||
import setup_hold
|
||||
from .lib import *
|
||||
from .delay import *
|
||||
from .setup_hold import *
|
||||
|
||||
|
||||
debug.info(2,"Initializing characterizer...")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import sys,re,shutil
|
|||
import debug
|
||||
import tech
|
||||
import math
|
||||
import stimuli
|
||||
from trim_spice import trim_spice
|
||||
import charutils as ch
|
||||
from .stimuli import *
|
||||
from .trim_spice import *
|
||||
from .charutils import *
|
||||
import utils
|
||||
from globals import OPTS
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class delay():
|
|||
self.sf.write("* Delay stimulus for period of {0}n load={1}fF slew={2}ns\n\n".format(self.period,
|
||||
self.load,
|
||||
self.slew))
|
||||
self.stim = stimuli.stimuli(self.sf, self.corner)
|
||||
self.stim = stimuli(self.sf, self.corner)
|
||||
# include files in stimulus file
|
||||
self.stim.write_include(self.trim_sp_file)
|
||||
|
||||
|
|
@ -339,16 +339,16 @@ class delay():
|
|||
# Checking from not data_value to data_value
|
||||
self.write_delay_stimulus()
|
||||
self.stim.run_sim()
|
||||
delay_hl = ch.parse_output("timing", "delay_hl")
|
||||
delay_lh = ch.parse_output("timing", "delay_lh")
|
||||
slew_hl = ch.parse_output("timing", "slew_hl")
|
||||
slew_lh = ch.parse_output("timing", "slew_lh")
|
||||
delay_hl = parse_output("timing", "delay_hl")
|
||||
delay_lh = parse_output("timing", "delay_lh")
|
||||
slew_hl = parse_output("timing", "slew_hl")
|
||||
slew_lh = parse_output("timing", "slew_lh")
|
||||
delays = (delay_hl, delay_lh, slew_hl, slew_lh)
|
||||
|
||||
read0_power=ch.parse_output("timing", "read0_power")
|
||||
write0_power=ch.parse_output("timing", "write0_power")
|
||||
read1_power=ch.parse_output("timing", "read1_power")
|
||||
write1_power=ch.parse_output("timing", "write1_power")
|
||||
read0_power=parse_output("timing", "read0_power")
|
||||
write0_power=parse_output("timing", "write0_power")
|
||||
read1_power=parse_output("timing", "read1_power")
|
||||
write1_power=parse_output("timing", "write1_power")
|
||||
|
||||
if not self.check_valid_delays(delays):
|
||||
return (False,{})
|
||||
|
|
@ -378,22 +378,24 @@ class delay():
|
|||
|
||||
self.write_power_stimulus(trim=False)
|
||||
self.stim.run_sim()
|
||||
leakage_power=ch.parse_output("timing", "leakage_power")
|
||||
leakage_power=parse_output("timing", "leakage_power")
|
||||
debug.check(leakage_power!="Failed","Could not measure leakage power.")
|
||||
|
||||
|
||||
self.write_power_stimulus(trim=True)
|
||||
self.stim.run_sim()
|
||||
trim_leakage_power=ch.parse_output("timing", "leakage_power")
|
||||
trim_leakage_power=parse_output("timing", "leakage_power")
|
||||
debug.check(trim_leakage_power!="Failed","Could not measure leakage power.")
|
||||
|
||||
# For debug, you sometimes want to inspect each simulation.
|
||||
#key=raw_input("press return to continue")
|
||||
return (leakage_power*1e3, trim_leakage_power*1e3)
|
||||
|
||||
def check_valid_delays(self, (delay_hl, delay_lh, slew_hl, slew_lh)):
|
||||
def check_valid_delays(self, delay_tuple):
|
||||
""" Check if the measurements are defined and if they are valid. """
|
||||
|
||||
(delay_hl, delay_lh, slew_hl, slew_lh) = delay_tuple
|
||||
|
||||
# if it failed or the read was longer than a period
|
||||
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
||||
debug.info(2,"Failed simulation: period {0} load {1} slew {2}, delay_hl={3}n delay_lh={4}ns slew_hl={5}n slew_lh={6}n".format(self.period,
|
||||
|
|
@ -457,7 +459,7 @@ class delay():
|
|||
else:
|
||||
lb_period = target_period
|
||||
|
||||
if ch.relative_compare(ub_period, lb_period, error_tolerance=0.05):
|
||||
if relative_compare(ub_period, lb_period, error_tolerance=0.05):
|
||||
# ub_period is always feasible
|
||||
return ub_period
|
||||
|
||||
|
|
@ -471,10 +473,10 @@ class delay():
|
|||
# Checking from not data_value to data_value
|
||||
self.write_delay_stimulus()
|
||||
self.stim.run_sim()
|
||||
delay_hl = ch.parse_output("timing", "delay_hl")
|
||||
delay_lh = ch.parse_output("timing", "delay_lh")
|
||||
slew_hl = ch.parse_output("timing", "slew_hl")
|
||||
slew_lh = ch.parse_output("timing", "slew_lh")
|
||||
delay_hl = parse_output("timing", "delay_hl")
|
||||
delay_lh = parse_output("timing", "delay_lh")
|
||||
slew_hl = parse_output("timing", "slew_hl")
|
||||
slew_lh = parse_output("timing", "slew_lh")
|
||||
# if it failed or the read was longer than a period
|
||||
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
||||
debug.info(2,"Invalid measures: Period {0}, delay_hl={1}ns, delay_lh={2}ns slew_hl={3}ns slew_lh={4}ns".format(self.period,
|
||||
|
|
@ -495,10 +497,10 @@ class delay():
|
|||
slew_lh))
|
||||
return False
|
||||
else:
|
||||
if not ch.relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
|
||||
if not relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
|
||||
debug.info(2,"Delay too big {0} vs {1}".format(delay_lh,feasible_delay_lh))
|
||||
return False
|
||||
elif not ch.relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
|
||||
elif not relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
|
||||
debug.info(2,"Delay too big {0} vs {1}".format(delay_hl,feasible_delay_hl))
|
||||
return False
|
||||
|
||||
|
|
@ -602,7 +604,7 @@ class delay():
|
|||
debug.info(1, "Min Period: {0}n with a delay of {1} / {2}".format(min_period, feasible_delay_lh, feasible_delay_hl))
|
||||
|
||||
# 4) Pack up the final measurements
|
||||
char_data["min_period"] = ch.round_time(min_period)
|
||||
char_data["min_period"] = round_time(min_period)
|
||||
|
||||
return char_data
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import os,sys,re
|
||||
import debug
|
||||
import math
|
||||
import setup_hold
|
||||
import delay
|
||||
import charutils as ch
|
||||
from .setup_hold import *
|
||||
from .delay import *
|
||||
from .charutils import *
|
||||
import tech
|
||||
import numpy as np
|
||||
from globals import OPTS
|
||||
|
|
@ -186,9 +186,9 @@ class lib:
|
|||
""" Helper function to create quoted, line wrapped array with each row of given length """
|
||||
# check that the length is a multiple or give an error!
|
||||
debug.check(len(values)%length == 0,"Values are not a multiple of the length. Cannot make a full array.")
|
||||
rounded_values = map(ch.round_time,values)
|
||||
rounded_values = list(map(round_time,values))
|
||||
split_values = [rounded_values[i:i+length] for i in range(0, len(rounded_values), length)]
|
||||
formatted_rows = map(self.create_list,split_values)
|
||||
formatted_rows = list(map(self.create_list,split_values))
|
||||
formatted_array = ",\\\n".join(formatted_rows)
|
||||
return formatted_array
|
||||
|
||||
|
|
@ -274,11 +274,11 @@ class lib:
|
|||
self.lib.write(" timing_type : setup_rising; \n")
|
||||
self.lib.write(" related_pin : \"clk\"; \n")
|
||||
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
||||
rounded_values = map(ch.round_time,self.times["setup_times_LH"])
|
||||
rounded_values = list(map(round_time,self.times["setup_times_LH"]))
|
||||
self.write_values(rounded_values,len(self.slews)," ")
|
||||
self.lib.write(" }\n")
|
||||
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
||||
rounded_values = map(ch.round_time,self.times["setup_times_HL"])
|
||||
rounded_values = list(map(round_time,self.times["setup_times_HL"]))
|
||||
self.write_values(rounded_values,len(self.slews)," ")
|
||||
self.lib.write(" }\n")
|
||||
self.lib.write(" }\n")
|
||||
|
|
@ -286,11 +286,11 @@ class lib:
|
|||
self.lib.write(" timing_type : hold_rising; \n")
|
||||
self.lib.write(" related_pin : \"clk\"; \n")
|
||||
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
||||
rounded_values = map(ch.round_time,self.times["hold_times_LH"])
|
||||
rounded_values = list(map(round_time,self.times["hold_times_LH"]))
|
||||
self.write_values(rounded_values,len(self.slews)," ")
|
||||
self.lib.write(" }\n")
|
||||
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
||||
rounded_values = map(ch.round_time,self.times["hold_times_HL"])
|
||||
rounded_values = list(map(round_time,self.times["hold_times_HL"]))
|
||||
self.write_values(rounded_values,len(self.slews)," ")
|
||||
self.lib.write(" }\n")
|
||||
self.lib.write(" }\n")
|
||||
|
|
@ -413,8 +413,8 @@ class lib:
|
|||
self.lib.write(" }\n")
|
||||
self.lib.write(" }\n")
|
||||
|
||||
min_pulse_width = ch.round_time(self.char_results["min_period"])/2.0
|
||||
min_period = ch.round_time(self.char_results["min_period"])
|
||||
min_pulse_width = round_time(self.char_results["min_period"])/2.0
|
||||
min_period = round_time(self.char_results["min_period"])
|
||||
self.lib.write(" timing(){ \n")
|
||||
self.lib.write(" timing_type :\"min_pulse_width\"; \n")
|
||||
self.lib.write(" related_pin : clk; \n")
|
||||
|
|
@ -443,7 +443,7 @@ class lib:
|
|||
try:
|
||||
self.d
|
||||
except AttributeError:
|
||||
self.d = delay.delay(self.sram, self.sp_file, self.corner)
|
||||
self.d = delay(self.sram, self.sp_file, self.corner)
|
||||
if self.use_model:
|
||||
self.char_results = self.d.analytical_delay(self.sram,self.slews,self.loads)
|
||||
else:
|
||||
|
|
@ -458,7 +458,7 @@ class lib:
|
|||
try:
|
||||
self.sh
|
||||
except AttributeError:
|
||||
self.sh = setup_hold.setup_hold(self.corner)
|
||||
self.sh = setup_hold(self.corner)
|
||||
if self.use_model:
|
||||
self.times = self.sh.analytical_setuphold(self.slews,self.loads)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import sys
|
||||
import tech
|
||||
import stimuli
|
||||
from .stimuli import *
|
||||
import debug
|
||||
import charutils as ch
|
||||
from .charutils import *
|
||||
import ms_flop
|
||||
from globals import OPTS
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class setup_hold():
|
|||
# creates and opens the stimulus file for writing
|
||||
temp_stim = OPTS.openram_temp + "stim.sp"
|
||||
self.sf = open(temp_stim, "w")
|
||||
self.stim = stimuli.stimuli(self.sf, self.corner)
|
||||
self.stim = stimuli(self.sf, self.corner)
|
||||
|
||||
self.write_header(correct_value)
|
||||
|
||||
|
|
@ -186,8 +186,8 @@ class setup_hold():
|
|||
target_time=feasible_bound,
|
||||
correct_value=correct_value)
|
||||
self.stim.run_sim()
|
||||
ideal_clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
|
||||
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
|
||||
ideal_clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
|
||||
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
|
||||
debug.info(2,"*** {0} CHECK: {1} Ideal Clk-to-Q: {2} Setup/Hold: {3}".format(mode, correct_value,ideal_clk_to_q,setuphold_time))
|
||||
|
||||
if type(ideal_clk_to_q)!=float or type(setuphold_time)!=float:
|
||||
|
|
@ -219,8 +219,8 @@ class setup_hold():
|
|||
|
||||
|
||||
self.stim.run_sim()
|
||||
clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
|
||||
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
|
||||
clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
|
||||
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
|
||||
if type(clk_to_q)==float and (clk_to_q<1.1*ideal_clk_to_q) and type(setuphold_time)==float:
|
||||
if mode == "SETUP": # SETUP is clk-din, not din-clk
|
||||
setuphold_time *= -1e9
|
||||
|
|
@ -235,7 +235,7 @@ class setup_hold():
|
|||
infeasible_bound = target_time
|
||||
|
||||
#raw_input("Press Enter to continue...")
|
||||
if ch.relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
|
||||
if relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
|
||||
debug.info(3,"CONVERGE {0} vs {1}".format(feasible_bound,infeasible_bound))
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ Python GDS Mill Package
|
|||
GDS Mill is a Python package for the creation and manipulation of binary GDS2 layout files.
|
||||
"""
|
||||
|
||||
from gds2reader import *
|
||||
from gds2writer import *
|
||||
from pdfLayout import *
|
||||
from vlsiLayout import *
|
||||
from gdsStreamer import *
|
||||
from gdsPrimitives import *
|
||||
from .gds2reader import *
|
||||
from .gds2writer import *
|
||||
#from .pdfLayout import *
|
||||
from .vlsiLayout import *
|
||||
from .gdsStreamer import *
|
||||
from .gdsPrimitives import *
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import struct
|
||||
from gdsPrimitives import *
|
||||
from .gdsPrimitives import *
|
||||
|
||||
class Gds2writer:
|
||||
"""Class to take a populated layout class and write it to a file in GDSII format"""
|
||||
|
|
@ -14,8 +14,8 @@ class Gds2writer:
|
|||
def print64AsBinary(self,number):
|
||||
#debugging method for binary inspection
|
||||
for index in range(0,64):
|
||||
print (number>>(63-index))&0x1,
|
||||
print "\n"
|
||||
print((number>>(63-index))&0x1,eol='')
|
||||
print("\n")
|
||||
|
||||
def ieeeDoubleFromIbmData(self,ibmData):
|
||||
#the GDS double is in IBM 370 format like this:
|
||||
|
|
@ -40,9 +40,9 @@ class Gds2writer:
|
|||
exponent-=1
|
||||
#check for underflow error -- should handle these properly!
|
||||
if(exponent<=0):
|
||||
print "Underflow Error"
|
||||
print("Underflow Error")
|
||||
elif(exponent == 2047):
|
||||
print "Overflow Error"
|
||||
print("Overflow Error")
|
||||
#re assemble
|
||||
newFloat=(sign<<63)|(exponent<<52)|((mantissa>>12)&0xfffffffffffff)
|
||||
asciiDouble = struct.pack('>q',newFloat)
|
||||
|
|
@ -84,12 +84,12 @@ class Gds2writer:
|
|||
data = struct.unpack('>q',asciiDouble)[0]
|
||||
sign = data >> 63
|
||||
exponent = ((data >> 52) & 0x7ff)-1023
|
||||
print exponent+1023
|
||||
print(exponent+1023)
|
||||
mantissa = data << 12 #chop off sign and exponent
|
||||
#self.print64AsBinary((sign<<63)|((exponent+1023)<<52)|(mantissa>>12))
|
||||
asciiDouble = struct.pack('>q',(sign<<63)|(exponent+1023<<52)|(mantissa>>12))
|
||||
newFloat = struct.unpack('>d',asciiDouble)[0]
|
||||
print "Check:"+str(newFloat)
|
||||
print("Check:"+str(newFloat))
|
||||
|
||||
def writeRecord(self,record):
|
||||
recordLength = len(record)+2 #make sure to include this in the length
|
||||
|
|
@ -99,12 +99,12 @@ class Gds2writer:
|
|||
def writeHeader(self):
|
||||
## Header
|
||||
if("gdsVersion" in self.layoutObject.info):
|
||||
idBits='\x00\x02'
|
||||
idBits=b'\x00\x02'
|
||||
gdsVersion = struct.pack(">h",self.layoutObject.info["gdsVersion"])
|
||||
self.writeRecord(idBits+gdsVersion)
|
||||
## Modified Date
|
||||
if("dates" in self.layoutObject.info):
|
||||
idBits='\x01\x02'
|
||||
idBits=b'\x01\x02'
|
||||
modYear = struct.pack(">h",self.layoutObject.info["dates"][0])
|
||||
modMonth = struct.pack(">h",self.layoutObject.info["dates"][1])
|
||||
modDay = struct.pack(">h",self.layoutObject.info["dates"][2])
|
||||
|
|
@ -122,43 +122,43 @@ class Gds2writer:
|
|||
lastAccessMinute+lastAccessSecond)
|
||||
## LibraryName
|
||||
if("libraryName" in self.layoutObject.info):
|
||||
idBits='\x02\x06'
|
||||
idBits=b'\x02\x06'
|
||||
if (len(self.layoutObject.info["libraryName"]) % 2 != 0):
|
||||
libraryName = self.layoutObject.info["libraryName"] + "\0"
|
||||
libraryName = self.layoutObject.info["libraryName"].encode() + "\0"
|
||||
else:
|
||||
libraryName = self.layoutObject.info["libraryName"]
|
||||
libraryName = self.layoutObject.info["libraryName"].encode()
|
||||
self.writeRecord(idBits+libraryName)
|
||||
## reference libraries
|
||||
if("referenceLibraries" in self.layoutObject.info):
|
||||
idBits='\x1F\x06'
|
||||
idBits=b'\x1F\x06'
|
||||
referenceLibraryA = self.layoutObject.info["referenceLibraries"][0]
|
||||
referenceLibraryB = self.layoutObject.info["referenceLibraries"][1]
|
||||
self.writeRecord(idBits+referenceLibraryA+referenceLibraryB)
|
||||
if("fonts" in self.layoutObject.info):
|
||||
idBits='\x20\x06'
|
||||
idBits=b'\x20\x06'
|
||||
fontA = self.layoutObject.info["fonts"][0]
|
||||
fontB = self.layoutObject.info["fonts"][1]
|
||||
fontC = self.layoutObject.info["fonts"][2]
|
||||
fontD = self.layoutObject.info["fonts"][3]
|
||||
self.writeRecord(idBits+fontA+fontB+fontC+fontD)
|
||||
if("attributeTable" in self.layoutObject.info):
|
||||
idBits='\x23\x06'
|
||||
idBits=b'\x23\x06'
|
||||
attributeTable = self.layoutObject.info["attributeTable"]
|
||||
self.writeRecord(idBits+attributeTable)
|
||||
if("generations" in self.layoutObject.info):
|
||||
idBits='\x22\x02'
|
||||
idBits=b'\x22\x02'
|
||||
generations = struct.pack(">h",self.layoutObject.info["generations"])
|
||||
self.writeRecord(idBits+generations)
|
||||
if("fileFormat" in self.layoutObject.info):
|
||||
idBits='\x36\x02'
|
||||
idBits=b'\x36\x02'
|
||||
fileFormat = struct.pack(">h",self.layoutObject.info["fileFormat"])
|
||||
self.writeRecord(idBits+fileFormat)
|
||||
if("mask" in self.layoutObject.info):
|
||||
idBits='\x37\x06'
|
||||
idBits=b'\x37\x06'
|
||||
mask = self.layoutObject.info["mask"]
|
||||
self.writeRecord(idBits+mask)
|
||||
if("units" in self.layoutObject.info):
|
||||
idBits='\x03\x05'
|
||||
idBits=b'\x03\x05'
|
||||
userUnits=self.ibmDataFromIeeeDouble(self.layoutObject.info["units"][0])
|
||||
dbUnits=self.ibmDataFromIeeeDouble((self.layoutObject.info["units"][0]*1e-6/self.layoutObject.info["units"][1])*self.layoutObject.info["units"][1])
|
||||
|
||||
|
|
@ -176,171 +176,171 @@ class Gds2writer:
|
|||
|
||||
self.writeRecord(idBits+userUnits+dbUnits)
|
||||
if(self.debugToTerminal==1):
|
||||
print "writer: userUnits %s"%(userUnits.encode("hex"))
|
||||
print "writer: dbUnits %s"%(dbUnits.encode("hex"))
|
||||
print("writer: userUnits %s"%(userUnits.encode("hex")))
|
||||
print("writer: dbUnits %s"%(dbUnits.encode("hex")))
|
||||
#self.ieeeFloatCheck(1.3e-6)
|
||||
|
||||
print "End of GDSII Header Written"
|
||||
print("End of GDSII Header Written")
|
||||
return 1
|
||||
|
||||
def writeBoundary(self,thisBoundary):
|
||||
idBits = '\x08\x00' #record Type
|
||||
idBits=b'\x08\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisBoundary.elementFlags!=""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisBoundary.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisBoundary.plex!=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisBoundary.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisBoundary.drawingLayer!=""):
|
||||
idBits='\x0D\x02' #drawig layer
|
||||
idBits=b'\x0D\x02' #drawig layer
|
||||
drawingLayer = struct.pack(">h",thisBoundary.drawingLayer)
|
||||
self.writeRecord(idBits+drawingLayer)
|
||||
if(thisBoundary.purposeLayer):
|
||||
idBits='\x16\x02' #purpose layer
|
||||
idBits=b'\x16\x02' #purpose layer
|
||||
purposeLayer = struct.pack(">h",thisBoundary.purposeLayer)
|
||||
self.writeRecord(idBits+purposeLayer)
|
||||
if(thisBoundary.dataType!=""):
|
||||
idBits='\x0E\x02'#DataType
|
||||
idBits=b'\x0E\x02'#DataType
|
||||
dataType = struct.pack(">h",thisBoundary.dataType)
|
||||
self.writeRecord(idBits+dataType)
|
||||
if(thisBoundary.coordinates!=""):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisBoundary.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writePath(self,thisPath): #writes out a path structure
|
||||
idBits = '\x09\x00' #record Type
|
||||
idBits=b'\x09\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisPath.elementFlags != ""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisPath.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisPath.plex!=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisPath.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisPath.drawingLayer):
|
||||
idBits='\x0D\x02' #drawig layer
|
||||
idBits=b'\x0D\x02' #drawig layer
|
||||
drawingLayer = struct.pack(">h",thisPath.drawingLayer)
|
||||
self.writeRecord(idBits+drawingLayer)
|
||||
if(thisPath.purposeLayer):
|
||||
idBits='\x16\x02' #purpose layer
|
||||
idBits=b'\x16\x02' #purpose layer
|
||||
purposeLayer = struct.pack(">h",thisPath.purposeLayer)
|
||||
self.writeRecord(idBits+purposeLayer)
|
||||
if(thisPath.pathType):
|
||||
idBits='\x21\x02' #Path type
|
||||
idBits=b'\x21\x02' #Path type
|
||||
pathType = struct.pack(">h",thisPath.pathType)
|
||||
self.writeRecord(idBits+pathType)
|
||||
if(thisPath.pathWidth):
|
||||
idBits='\x0F\x03'
|
||||
idBits=b'\x0F\x03'
|
||||
pathWidth = struct.pack(">i",thisPath.pathWidth)
|
||||
self.writeRecord(idBits+pathWidth)
|
||||
if(thisPath.coordinates):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisPath.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeSref(self,thisSref): #reads in a reference to another structure
|
||||
idBits = '\x0A\x00' #record Type
|
||||
idBits=b'\x0A\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisSref.elementFlags != ""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisSref.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisSref.plex!=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisSref.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisSref.sName!=""):
|
||||
idBits='\x12\x06'
|
||||
idBits=b'\x12\x06'
|
||||
if (len(thisSref.sName) % 2 != 0):
|
||||
sName = thisSref.sName+"\0"
|
||||
else:
|
||||
sName = thisSref.sName
|
||||
self.writeRecord(idBits+sName)
|
||||
self.writeRecord(idBits+sName.encode())
|
||||
if(thisSref.transFlags!=""):
|
||||
idBits='\x1A\x01'
|
||||
idBits=b'\x1A\x01'
|
||||
mirrorFlag = int(thisSref.transFlags[0])<<15
|
||||
rotateFlag = int(thisSref.transFlags[1])<<1
|
||||
magnifyFlag = int(thisSref.transFlags[2])<<3
|
||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||
self.writeRecord(idBits+transFlags)
|
||||
if(thisSref.magFactor!=""):
|
||||
idBits='\x1B\x05'
|
||||
idBits=b'\x1B\x05'
|
||||
magFactor=self.ibmDataFromIeeeDouble(thisSref.magFactor)
|
||||
self.writeRecord(idBits+magFactor)
|
||||
if(thisSref.rotateAngle!=""):
|
||||
idBits='\x1C\x05'
|
||||
idBits=b'\x1C\x05'
|
||||
rotateAngle=self.ibmDataFromIeeeDouble(thisSref.rotateAngle)
|
||||
self.writeRecord(idBits+rotateAngle)
|
||||
if(thisSref.coordinates!=""):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
coordinate = thisSref.coordinates
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
#print thisSref.coordinates
|
||||
#print(thisSref.coordinates)
|
||||
self.writeRecord(coordinateRecord)
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeAref(self,thisAref): #an array of references
|
||||
idBits = '\x0B\x00' #record Type
|
||||
idBits=b'\x0B\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisAref.elementFlags!=""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisAref.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisAref.plex):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisAref.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisAref.aName):
|
||||
idBits='\x12\x06'
|
||||
idBits=b'\x12\x06'
|
||||
if (len(thisAref.aName) % 2 != 0):
|
||||
aName = thisAref.aName+"\0"
|
||||
else:
|
||||
aName = thisAref.aName
|
||||
self.writeRecord(idBits+aName)
|
||||
if(thisAref.transFlags):
|
||||
idBits='\x1A\x01'
|
||||
idBits=b'\x1A\x01'
|
||||
mirrorFlag = int(thisAref.transFlags[0])<<15
|
||||
rotateFlag = int(thisAref.transFlags[1])<<1
|
||||
magnifyFlag = int(thisAref.transFlags[0])<<3
|
||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||
self.writeRecord(idBits+transFlags)
|
||||
if(thisAref.magFactor):
|
||||
idBits='\x1B\x05'
|
||||
idBits=b'\x1B\x05'
|
||||
magFactor=self.ibmDataFromIeeeDouble(thisAref.magFactor)
|
||||
self.writeRecord(idBits+magFactor)
|
||||
if(thisAref.rotateAngle):
|
||||
idBits='\x1C\x05'
|
||||
idBits=b'\x1C\x05'
|
||||
rotateAngle=self.ibmDataFromIeeeDouble(thisAref.rotateAngle)
|
||||
self.writeRecord(idBits+rotateAngle)
|
||||
if(thisAref.coordinates):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisAref.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
|
|
@ -348,151 +348,151 @@ class Gds2writer:
|
|||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeText(self,thisText):
|
||||
idBits = '\x0C\x00' #record Type
|
||||
idBits=b'\x0C\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisText.elementFlags!=""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisText.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisText.plex !=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisText.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisText.drawingLayer != ""):
|
||||
idBits='\x0D\x02' #drawing layer
|
||||
idBits=b'\x0D\x02' #drawing layer
|
||||
drawingLayer = struct.pack(">h",thisText.drawingLayer)
|
||||
self.writeRecord(idBits+drawingLayer)
|
||||
#if(thisText.purposeLayer):
|
||||
idBits='\x16\x02' #purpose layer
|
||||
idBits=b'\x16\x02' #purpose layer
|
||||
purposeLayer = struct.pack(">h",thisText.purposeLayer)
|
||||
self.writeRecord(idBits+purposeLayer)
|
||||
if(thisText.transFlags != ""):
|
||||
idBits='\x1A\x01'
|
||||
idBits=b'\x1A\x01'
|
||||
mirrorFlag = int(thisText.transFlags[0])<<15
|
||||
rotateFlag = int(thisText.transFlags[1])<<1
|
||||
magnifyFlag = int(thisText.transFlags[0])<<3
|
||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||
self.writeRecord(idBits+transFlags)
|
||||
if(thisText.magFactor != ""):
|
||||
idBits='\x1B\x05'
|
||||
idBits=b'\x1B\x05'
|
||||
magFactor=self.ibmDataFromIeeeDouble(thisText.magFactor)
|
||||
self.writeRecord(idBits+magFactor)
|
||||
if(thisText.rotateAngle != ""):
|
||||
idBits='\x1C\x05'
|
||||
idBits=b'\x1C\x05'
|
||||
rotateAngle=self.ibmDataFromIeeeDouble(thisText.rotateAngle)
|
||||
self.writeRecord(idBits+rotateAngle)
|
||||
if(thisText.pathType !=""):
|
||||
idBits='\x21\x02' #Path type
|
||||
idBits=b'\x21\x02' #Path type
|
||||
pathType = struct.pack(">h",thisText.pathType)
|
||||
self.writeRecord(idBits+pathType)
|
||||
if(thisText.pathWidth != ""):
|
||||
idBits='\x0F\x03'
|
||||
idBits=b'\x0F\x03'
|
||||
pathWidth = struct.pack(">i",thisText.pathWidth)
|
||||
self.writeRecord(idBits+pathWidth)
|
||||
if(thisText.presentationFlags!=""):
|
||||
idBits='\x1A\x01'
|
||||
idBits=b'\x1A\x01'
|
||||
font = thisText.presentationFlags[0]<<4
|
||||
verticalFlags = int(thisText.presentationFlags[1])<<2
|
||||
horizontalFlags = int(thisText.presentationFlags[2])
|
||||
presentationFlags = struct.pack(">H",font|verticalFlags|horizontalFlags)
|
||||
self.writeRecord(idBits+transFlags)
|
||||
if(thisText.coordinates!=""):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisText.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
if(thisText.textString):
|
||||
idBits='\x19\x06'
|
||||
idBits=b'\x19\x06'
|
||||
textString = thisText.textString
|
||||
self.writeRecord(idBits+textString)
|
||||
self.writeRecord(idBits+textString.encode())
|
||||
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeNode(self,thisNode):
|
||||
idBits = '\x15\x00' #record Type
|
||||
idBits=b'\x15\x00' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisNode.elementFlags!=""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisNode.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisNode.plex!=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisNode.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisNode.drawingLayer!=""):
|
||||
idBits='\x0D\x02' #drawig layer
|
||||
idBits=b'\x0D\x02' #drawig layer
|
||||
drawingLayer = struct.pack(">h",thisNode.drawingLayer)
|
||||
self.writeRecord(idBits+drawingLayer)
|
||||
if(thisNode.nodeType!=""):
|
||||
idBits='\x2A\x02'
|
||||
idBits=b'\x2A\x02'
|
||||
nodeType = struct.pack(">h",thisNode.nodeType)
|
||||
self.writeRecord(idBits+nodeType)
|
||||
if(thisText.coordinates!=""):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisText.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeBox(self,thisBox):
|
||||
idBits = '\x2E\x02' #record Type
|
||||
idBits=b'\x2E\x02' #record Type
|
||||
self.writeRecord(idBits)
|
||||
if(thisBox.elementFlags!=""):
|
||||
idBits='\x26\x01' #ELFLAGS
|
||||
idBits=b'\x26\x01' #ELFLAGS
|
||||
elementFlags = struct.pack(">h",thisBox.elementFlags)
|
||||
self.writeRecord(idBits+elementFlags)
|
||||
if(thisBox.plex!=""):
|
||||
idBits='\x2F\x03' #PLEX
|
||||
idBits=b'\x2F\x03' #PLEX
|
||||
plex = struct.pack(">i",thisBox.plex)
|
||||
self.writeRecord(idBits+plex)
|
||||
if(thisBox.drawingLayer!=""):
|
||||
idBits='\x0D\x02' #drawig layer
|
||||
idBits=b'\x0D\x02' #drawig layer
|
||||
drawingLayer = struct.pack(">h",thisBox.drawingLayer)
|
||||
self.writeRecord(idBits+drawingLayer)
|
||||
if(thisBox.purposeLayer):
|
||||
idBits='\x16\x02' #purpose layer
|
||||
idBits=b'\x16\x02' #purpose layer
|
||||
purposeLayer = struct.pack(">h",thisBox.purposeLayer)
|
||||
self.writeRecord(idBits+purposeLayer)
|
||||
if(thisBox.boxValue!=""):
|
||||
idBits='\x2D\x00'
|
||||
idBits=b'\x2D\x00'
|
||||
boxValue = struct.pack(">h",thisBox.boxValue)
|
||||
self.writeRecord(idBits+boxValue)
|
||||
if(thisBox.coordinates!=""):
|
||||
idBits='\x10\x03' #XY Data Points
|
||||
idBits=b'\x10\x03' #XY Data Points
|
||||
coordinateRecord = idBits
|
||||
for coordinate in thisBox.coordinates:
|
||||
x=struct.pack(">i",coordinate[0])
|
||||
y=struct.pack(">i",coordinate[1])
|
||||
x=struct.pack(">i",int(coordinate[0]))
|
||||
y=struct.pack(">i",int(coordinate[1]))
|
||||
coordinateRecord+=x
|
||||
coordinateRecord+=y
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
idBits='\x11\x00' #End Of Element
|
||||
idBits=b'\x11\x00' #End Of Element
|
||||
coordinateRecord = idBits
|
||||
self.writeRecord(coordinateRecord)
|
||||
|
||||
def writeNextStructure(self,structureName):
|
||||
#first put in the structure head
|
||||
thisStructure = self.layoutObject.structures[structureName]
|
||||
idBits='\x05\x02'
|
||||
idBits=b'\x05\x02'
|
||||
createYear = struct.pack(">h",thisStructure.createDate[0])
|
||||
createMonth = struct.pack(">h",thisStructure.createDate[1])
|
||||
createDay = struct.pack(">h",thisStructure.createDate[2])
|
||||
|
|
@ -508,12 +508,12 @@ class Gds2writer:
|
|||
self.writeRecord(idBits+createYear+createMonth+createDay+createHour+createMinute+createSecond\
|
||||
+modYear+modMonth+modDay+modHour+modMinute+modSecond)
|
||||
#now the structure name
|
||||
idBits='\x06\x06'
|
||||
idBits=b'\x06\x06'
|
||||
##caveat: the name needs to be an EVEN number of characters
|
||||
if(len(structureName)%2 == 1):
|
||||
#pad with a zero
|
||||
structureName = structureName + '\x00'
|
||||
self.writeRecord(idBits+structureName)
|
||||
self.writeRecord(idBits+structureName.encode())
|
||||
#now go through all the structure elements and write them in
|
||||
|
||||
for boundary in thisStructure.boundaries:
|
||||
|
|
@ -531,7 +531,7 @@ class Gds2writer:
|
|||
for box in thisStructure.boxes:
|
||||
self.writeBox(box)
|
||||
#put in the structure tail
|
||||
idBits='\x07\x00'
|
||||
idBits=b'\x07\x00'
|
||||
self.writeRecord(idBits)
|
||||
|
||||
def writeGds2(self):
|
||||
|
|
@ -540,7 +540,7 @@ class Gds2writer:
|
|||
for structureName in self.layoutObject.structures:
|
||||
self.writeNextStructure(structureName)
|
||||
#at the end, put in the END LIB record
|
||||
idBits='\x04\x00'
|
||||
idBits=b'\x04\x00'
|
||||
self.writeRecord(idBits)
|
||||
|
||||
def writeToFile(self,fileName):
|
||||
|
|
|
|||
|
|
@ -122,11 +122,11 @@ class GdsStreamer:
|
|||
#stream the gds out from cadence
|
||||
worker = os.popen("pipo strmout "+self.workingDirectory+"/partStreamOut.tmpl")
|
||||
#dump the outputs to the screen line by line
|
||||
print "Streaming Out From Cadence......"
|
||||
print("Streaming Out From Cadence......")
|
||||
while 1:
|
||||
line = worker.readline()
|
||||
if not line: break #this means sim is finished so jump out
|
||||
#else: print line #for debug only
|
||||
#else: print(line) #for debug only
|
||||
worker.close()
|
||||
#now remove the template file
|
||||
os.remove(self.workingDirectory+"/partStreamOut.tmpl")
|
||||
|
|
@ -142,13 +142,13 @@ class GdsStreamer:
|
|||
#stream the gds out from cadence
|
||||
worker = os.popen("pipo strmin "+self.workingDirectory+"/partStreamIn.tmpl")
|
||||
#dump the outputs to the screen line by line
|
||||
print "Streaming In To Cadence......"
|
||||
print("Streaming In To Cadence......")
|
||||
while 1:
|
||||
line = worker.readline()
|
||||
if not line: break #this means sim is finished so jump out
|
||||
#else: print line #for debug only
|
||||
#else: print(line) #for debug only
|
||||
worker.close()
|
||||
#now remove the template file
|
||||
os.remove(self.workingDirectory+"/partStreamIn.tmpl")
|
||||
#and go back to whever it was we started from
|
||||
os.chdir(currentPath)
|
||||
os.chdir(currentPath)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pyx
|
||||
import math
|
||||
import mpmath
|
||||
from numpy import matrix
|
||||
from gdsPrimitives import *
|
||||
import random
|
||||
|
||||
|
|
@ -39,12 +39,12 @@ class pdfLayout:
|
|||
"""
|
||||
xyCoordinates = []
|
||||
#setup a translation matrix
|
||||
tMatrix = mpmath.matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
|
||||
tMatrix = matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
|
||||
#and a rotation matrix
|
||||
rMatrix = mpmath.matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
|
||||
rMatrix = matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
|
||||
for coordinate in uvCoordinates:
|
||||
#grab the point in UV space
|
||||
uvPoint = mpmath.matrix([coordinate[0],coordinate[1],1.0])
|
||||
uvPoint = matrix([coordinate[0],coordinate[1],1.0])
|
||||
#now rotate and translate it back to XY space
|
||||
xyPoint = rMatrix * uvPoint
|
||||
xyPoint = tMatrix * xyPoint
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from gdsPrimitives import *
|
||||
from .gdsPrimitives import *
|
||||
from datetime import *
|
||||
import mpmath
|
||||
import gdsPrimitives
|
||||
#from mpmath import matrix
|
||||
from numpy import matrix
|
||||
#import gdsPrimitives
|
||||
import debug
|
||||
|
||||
class VlsiLayout:
|
||||
|
|
@ -10,7 +11,7 @@ class VlsiLayout:
|
|||
def __init__(self, name=None, units=(0.001,1e-9), libraryName = "DEFAULT.DB", gdsVersion=5):
|
||||
#keep a list of all the structures in this layout
|
||||
self.units = units
|
||||
#print units
|
||||
#print(units)
|
||||
modDate = datetime.now()
|
||||
self.structures=dict()
|
||||
self.layerNumbersInUse = []
|
||||
|
|
@ -89,7 +90,7 @@ class VlsiLayout:
|
|||
|
||||
def newLayout(self,newName):
|
||||
#if (newName == "" | newName == 0):
|
||||
# print("ERROR: vlsiLayout.py:newLayout newName is null")
|
||||
# print("ERROR: vlsiLayout.py:newLayout newName is null")
|
||||
|
||||
#make sure the newName is a multiple of 2 characters
|
||||
#if(len(newName)%2 == 1):
|
||||
|
|
@ -134,13 +135,12 @@ class VlsiLayout:
|
|||
self.populateCoordinateMap()
|
||||
|
||||
def deduceHierarchy(self):
|
||||
#first, find the root of the tree.
|
||||
#go through and get the name of every structure.
|
||||
#then, go through and find which structure is not
|
||||
#contained by any other structure. this is the root.
|
||||
""" First, find the root of the tree.
|
||||
Then go through and get the name of every structure.
|
||||
Then, go through and find which structure is not
|
||||
contained by any other structure. this is the root."""
|
||||
structureNames=[]
|
||||
for name in self.structures:
|
||||
#print "deduceHierarchy: structure.name[%s]",name //FIXME: Added By Tom G.
|
||||
structureNames+=[name]
|
||||
|
||||
for name in self.structures:
|
||||
|
|
@ -148,7 +148,7 @@ class VlsiLayout:
|
|||
for sref in self.structures[name].srefs: #go through each reference
|
||||
if sref.sName in structureNames: #and compare to our list
|
||||
structureNames.remove(sref.sName)
|
||||
|
||||
|
||||
self.rootStructureName = structureNames[0]
|
||||
|
||||
def traverseTheHierarchy(self, startingStructureName=None, delegateFunction = None,
|
||||
|
|
@ -163,19 +163,20 @@ class VlsiLayout:
|
|||
rotateAngle = 0
|
||||
else:
|
||||
rotateAngle = math.radians(float(rotateAngle))
|
||||
mRotate = mpmath.matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
|
||||
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],[0.0,0.0,1.0],])
|
||||
mRotate = matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
|
||||
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],
|
||||
[0.0,0.0,1.0]])
|
||||
#set up the translation matrix
|
||||
translateX = float(coordinates[0])
|
||||
translateY = float(coordinates[1])
|
||||
mTranslate = mpmath.matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
|
||||
mTranslate = matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
|
||||
#set up the scale matrix (handles mirror X)
|
||||
scaleX = 1.0
|
||||
if(transFlags[0]):
|
||||
scaleY = -1.0
|
||||
else:
|
||||
scaleY = 1.0
|
||||
mScale = mpmath.matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
|
||||
mScale = matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
|
||||
|
||||
#we need to keep track of all transforms in the hierarchy
|
||||
#when we add an element to the xy tree, we apply all transforms from the bottom up
|
||||
|
|
@ -197,7 +198,7 @@ class VlsiLayout:
|
|||
transFlags = sref.transFlags,
|
||||
coordinates = sref.coordinates)
|
||||
# else:
|
||||
# print "WARNING: via encountered, ignoring:", sref.sName
|
||||
# print("WARNING: via encountered, ignoring:", sref.sName)
|
||||
#MUST HANDLE AREFs HERE AS WELL
|
||||
#when we return, drop the last transform from the transformPath
|
||||
del transformPath[-1]
|
||||
|
|
@ -210,10 +211,10 @@ class VlsiLayout:
|
|||
|
||||
def populateCoordinateMap(self):
|
||||
def addToXyTree(startingStructureName = None,transformPath = None):
|
||||
#print"populateCoordinateMap"
|
||||
uVector = mpmath.matrix([1.0,0.0,0.0]) #start with normal basis vectors
|
||||
vVector = mpmath.matrix([0.0,1.0,0.0])
|
||||
origin = mpmath.matrix([0.0,0.0,1.0]) #and an origin (Z component is 1.0 to indicate position instead of vector)
|
||||
#print("populateCoordinateMap")
|
||||
uVector = matrix([1.0,0.0,0.0]).transpose() #start with normal basis vectors
|
||||
vVector = matrix([0.0,1.0,0.0]).transpose()
|
||||
origin = matrix([0.0,0.0,1.0]).transpose() #and an origin (Z component is 1.0 to indicate position instead of vector)
|
||||
#make a copy of all the transforms and reverse it
|
||||
reverseTransformPath = transformPath[:]
|
||||
if len(reverseTransformPath) > 1:
|
||||
|
|
@ -245,7 +246,7 @@ class VlsiLayout:
|
|||
#userUnitsPerMicron = userUnit / 1e-6
|
||||
userUnitsPerMicron = userUnit / (userUnit)
|
||||
layoutUnitsPerMicron = userUnitsPerMicron / self.units[0]
|
||||
#print "userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron]
|
||||
#print("userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron])
|
||||
return round(microns*layoutUnitsPerMicron,0)
|
||||
|
||||
def changeRoot(self,newRoot, create=False):
|
||||
|
|
@ -259,7 +260,7 @@ class VlsiLayout:
|
|||
# Determine if newRoot exists
|
||||
# layoutToAdd (default) or nameOfLayout
|
||||
if (newRoot == 0 | ((newRoot not in self.structures) & ~create)):
|
||||
print "ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot
|
||||
print("ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot)
|
||||
exit(1)
|
||||
else:
|
||||
if ((newRoot not in self.structures) & create):
|
||||
|
|
@ -308,13 +309,13 @@ class VlsiLayout:
|
|||
self.layerNumbersInUse += [layerNumber]
|
||||
#Also, check if the user units / microns is the same as this Layout
|
||||
#if (layoutToAdd.units != self.units):
|
||||
#print "WARNING: VlsiLayout: Units from design to be added do not match target Layout"
|
||||
#print("WARNING: VlsiLayout: Units from design to be added do not match target Layout")
|
||||
|
||||
# if debug: print "DEBUG: vlsilayout: Using %d layers"
|
||||
# if debug: print("DEBUG: vlsilayout: Using %d layers")
|
||||
|
||||
# If we can't find the structure, error
|
||||
#if StructureFound == False:
|
||||
#print "ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout)
|
||||
#print("ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout))
|
||||
#return #FIXME: remove!
|
||||
#exit(1)
|
||||
|
||||
|
|
@ -353,10 +354,10 @@ class VlsiLayout:
|
|||
Method to add a box to a layout
|
||||
"""
|
||||
offsetInLayoutUnits = (self.userUnits(offsetInMicrons[0]),self.userUnits(offsetInMicrons[1]))
|
||||
#print "addBox:offsetInLayoutUnits",offsetInLayoutUnits
|
||||
#print("addBox:offsetInLayoutUnits",offsetInLayoutUnits)
|
||||
widthInLayoutUnits = self.userUnits(width)
|
||||
heightInLayoutUnits = self.userUnits(height)
|
||||
#print "offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits
|
||||
#print("offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits)
|
||||
if not center:
|
||||
coordinates=[offsetInLayoutUnits,
|
||||
(offsetInLayoutUnits[0]+widthInLayoutUnits,offsetInLayoutUnits[1]),
|
||||
|
|
@ -522,7 +523,7 @@ class VlsiLayout:
|
|||
heightInBlocks = int(coverageHeight/effectiveBlock)
|
||||
passFailRecord = []
|
||||
|
||||
print "Filling layer:",layerToFill
|
||||
print("Filling layer:",layerToFill)
|
||||
def isThisBlockOk(startingStructureName,coordinates,rotateAngle=None):
|
||||
#go through every boundary and check
|
||||
for boundary in self.structures[startingStructureName].boundaries:
|
||||
|
|
@ -568,7 +569,7 @@ class VlsiLayout:
|
|||
#if its bad, this global tempPassFail will be false
|
||||
#if true, we can add the block
|
||||
passFailRecord+=[self.tempPassFail]
|
||||
print "Percent Complete:"+str(percentDone)
|
||||
print("Percent Complete:"+str(percentDone))
|
||||
|
||||
|
||||
passFailIndex=0
|
||||
|
|
@ -579,7 +580,7 @@ class VlsiLayout:
|
|||
if passFailRecord[passFailIndex]:
|
||||
self.addBox(layerToFill, (blockX,blockY), width=blockSize, height=blockSize)
|
||||
passFailIndex+=1
|
||||
print "Done\n\n"
|
||||
print("Done\n\n")
|
||||
|
||||
def getLayoutBorder(self,borderlayer):
|
||||
for boundary in self.structures[self.rootStructureName].boundaries:
|
||||
|
|
@ -591,7 +592,7 @@ class VlsiLayout:
|
|||
cellSize=[right_top[0]-left_bottom[0],right_top[1]-left_bottom[1]]
|
||||
cellSizeMicron=[cellSize[0]*self.units[0],cellSize[1]*self.units[0]]
|
||||
if not(cellSizeMicron):
|
||||
print "Error: "+str(self.rootStructureName)+".cell_size information not found yet"
|
||||
print("Error: "+str(self.rootStructureName)+".cell_size information not found yet")
|
||||
return cellSizeMicron
|
||||
|
||||
def measureSize(self,startStructure):
|
||||
|
|
@ -700,7 +701,7 @@ class VlsiLayout:
|
|||
debug.warning("Did not find pin on layer {0} at coordinate {1}".format(layer, coordinate))
|
||||
|
||||
# sort the boundaries, return the max area pin boundary
|
||||
pin_boundaries.sort(cmpBoundaryAreas,reverse=True)
|
||||
pin_boundaries.sort(key=boundaryArea,reverse=True)
|
||||
pin_boundary=pin_boundaries[0]
|
||||
|
||||
# Convert to USER units
|
||||
|
|
@ -743,7 +744,8 @@ class VlsiLayout:
|
|||
shape_list=[]
|
||||
for label in label_list:
|
||||
(label_coordinate,label_layer)=label
|
||||
shape_list.append(self.getPinShapeByDBLocLayer(label_coordinate, label_layer))
|
||||
shape = self.getPinShapeByDBLocLayer(label_coordinate, label_layer)
|
||||
shape_list.append(shape)
|
||||
return shape_list
|
||||
|
||||
def getAllPinShapesByLabel(self,label_name):
|
||||
|
|
@ -797,23 +799,23 @@ class VlsiLayout:
|
|||
# Rectangle is [leftx, bottomy, rightx, topy].
|
||||
boundaryRect=[left_bottom[0],left_bottom[1],right_top[0],right_top[1]]
|
||||
boundaryRect=self.transformRectangle(boundaryRect,structureuVector,structurevVector)
|
||||
boundaryRect=[boundaryRect[0]+structureOrigin[0],boundaryRect[1]+structureOrigin[1],
|
||||
boundaryRect[2]+structureOrigin[0],boundaryRect[3]+structureOrigin[1]]
|
||||
|
||||
boundaryRect=[boundaryRect[0]+structureOrigin[0].item(),boundaryRect[1]+structureOrigin[1].item(),
|
||||
boundaryRect[2]+structureOrigin[0].item(),boundaryRect[3]+structureOrigin[1].item()]
|
||||
|
||||
if self.labelInRectangle(coordinates,boundaryRect):
|
||||
boundaries.append(boundaryRect)
|
||||
|
||||
return boundaries
|
||||
|
||||
def transformRectangle(self,orignalRectangle,uVector,vVector):
|
||||
def transformRectangle(self,originalRectangle,uVector,vVector):
|
||||
"""
|
||||
Transforms the four coordinates of a rectangle in space
|
||||
and recomputes the left, bottom, right, top values.
|
||||
"""
|
||||
leftBottom=mpmath.matrix([orignalRectangle[0],orignalRectangle[1]])
|
||||
leftBottom=[originalRectangle[0],originalRectangle[1]]
|
||||
leftBottom=self.transformCoordinate(leftBottom,uVector,vVector)
|
||||
|
||||
rightTop=mpmath.matrix([orignalRectangle[2],orignalRectangle[3]])
|
||||
rightTop=[originalRectangle[2],originalRectangle[3]]
|
||||
rightTop=self.transformCoordinate(rightTop,uVector,vVector)
|
||||
|
||||
left=min(leftBottom[0],rightTop[0])
|
||||
|
|
@ -821,14 +823,15 @@ class VlsiLayout:
|
|||
right=max(leftBottom[0],rightTop[0])
|
||||
top=max(leftBottom[1],rightTop[1])
|
||||
|
||||
return [left,bottom,right,top]
|
||||
newRectangle = [left,bottom,right,top]
|
||||
return newRectangle
|
||||
|
||||
def transformCoordinate(self,coordinate,uVector,vVector):
|
||||
"""
|
||||
Rotate a coordinate in space.
|
||||
"""
|
||||
x=coordinate[0]*uVector[0]+coordinate[1]*uVector[1]
|
||||
y=coordinate[1]*vVector[1]+coordinate[0]*vVector[0]
|
||||
x=coordinate[0]*uVector[0].item()+coordinate[1]*uVector[1].item()
|
||||
y=coordinate[1]*vVector[1].item()+coordinate[0]*vVector[0].item()
|
||||
transformCoordinate=[x,y]
|
||||
|
||||
return transformCoordinate
|
||||
|
|
@ -845,18 +848,12 @@ class VlsiLayout:
|
|||
else:
|
||||
return False
|
||||
|
||||
def cmpBoundaryAreas(A,B):
|
||||
def boundaryArea(A):
|
||||
"""
|
||||
Compares two rectangles and return true if Area(A)>Area(B).
|
||||
Returns boundary area for sorting.
|
||||
"""
|
||||
area_A=(A[2]-A[0])*(A[3]-A[1])
|
||||
area_B=(B[2]-B[0])*(B[3]-B[1])
|
||||
if area_A>area_B:
|
||||
return 1
|
||||
elif area_A==area_B:
|
||||
return 0
|
||||
else:
|
||||
return -1
|
||||
return area_A
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,386 +0,0 @@
|
|||
__version__ = '0.14'
|
||||
|
||||
from usertools import monitor, timing
|
||||
|
||||
from ctx_fp import FPContext
|
||||
from ctx_mp import MPContext
|
||||
|
||||
fp = FPContext()
|
||||
mp = MPContext()
|
||||
|
||||
fp._mp = mp
|
||||
mp._mp = mp
|
||||
mp._fp = fp
|
||||
fp._fp = fp
|
||||
|
||||
# XXX: extremely bad pickle hack
|
||||
import ctx_mp as _ctx_mp
|
||||
_ctx_mp._mpf_module.mpf = mp.mpf
|
||||
_ctx_mp._mpf_module.mpc = mp.mpc
|
||||
|
||||
make_mpf = mp.make_mpf
|
||||
make_mpc = mp.make_mpc
|
||||
|
||||
extraprec = mp.extraprec
|
||||
extradps = mp.extradps
|
||||
workprec = mp.workprec
|
||||
workdps = mp.workdps
|
||||
|
||||
mag = mp.mag
|
||||
|
||||
bernfrac = mp.bernfrac
|
||||
|
||||
jdn = mp.jdn
|
||||
jsn = mp.jsn
|
||||
jcn = mp.jcn
|
||||
jtheta = mp.jtheta
|
||||
calculate_nome = mp.calculate_nome
|
||||
|
||||
nint_distance = mp.nint_distance
|
||||
|
||||
plot = mp.plot
|
||||
cplot = mp.cplot
|
||||
splot = mp.splot
|
||||
|
||||
odefun = mp.odefun
|
||||
|
||||
jacobian = mp.jacobian
|
||||
findroot = mp.findroot
|
||||
multiplicity = mp.multiplicity
|
||||
|
||||
isinf = mp.isinf
|
||||
isnan = mp.isnan
|
||||
isint = mp.isint
|
||||
almosteq = mp.almosteq
|
||||
nan = mp.nan
|
||||
rand = mp.rand
|
||||
|
||||
absmin = mp.absmin
|
||||
absmax = mp.absmax
|
||||
|
||||
fraction = mp.fraction
|
||||
|
||||
linspace = mp.linspace
|
||||
arange = mp.arange
|
||||
|
||||
mpmathify = convert = mp.convert
|
||||
mpc = mp.mpc
|
||||
mpi = mp.mpi
|
||||
|
||||
nstr = mp.nstr
|
||||
nprint = mp.nprint
|
||||
chop = mp.chop
|
||||
|
||||
fneg = mp.fneg
|
||||
fadd = mp.fadd
|
||||
fsub = mp.fsub
|
||||
fmul = mp.fmul
|
||||
fdiv = mp.fdiv
|
||||
fprod = mp.fprod
|
||||
|
||||
quad = mp.quad
|
||||
quadgl = mp.quadgl
|
||||
quadts = mp.quadts
|
||||
quadosc = mp.quadosc
|
||||
|
||||
pslq = mp.pslq
|
||||
identify = mp.identify
|
||||
findpoly = mp.findpoly
|
||||
|
||||
richardson = mp.richardson
|
||||
shanks = mp.shanks
|
||||
nsum = mp.nsum
|
||||
nprod = mp.nprod
|
||||
diff = mp.diff
|
||||
diffs = mp.diffs
|
||||
diffun = mp.diffun
|
||||
differint = mp.differint
|
||||
taylor = mp.taylor
|
||||
pade = mp.pade
|
||||
polyval = mp.polyval
|
||||
polyroots = mp.polyroots
|
||||
fourier = mp.fourier
|
||||
fourierval = mp.fourierval
|
||||
sumem = mp.sumem
|
||||
chebyfit = mp.chebyfit
|
||||
limit = mp.limit
|
||||
|
||||
matrix = mp.matrix
|
||||
eye = mp.eye
|
||||
diag = mp.diag
|
||||
zeros = mp.zeros
|
||||
ones = mp.ones
|
||||
hilbert = mp.hilbert
|
||||
randmatrix = mp.randmatrix
|
||||
swap_row = mp.swap_row
|
||||
extend = mp.extend
|
||||
norm = mp.norm
|
||||
mnorm = mp.mnorm
|
||||
|
||||
lu_solve = mp.lu_solve
|
||||
lu = mp.lu
|
||||
unitvector = mp.unitvector
|
||||
inverse = mp.inverse
|
||||
residual = mp.residual
|
||||
qr_solve = mp.qr_solve
|
||||
cholesky = mp.cholesky
|
||||
cholesky_solve = mp.cholesky_solve
|
||||
det = mp.det
|
||||
cond = mp.cond
|
||||
|
||||
expm = mp.expm
|
||||
sqrtm = mp.sqrtm
|
||||
powm = mp.powm
|
||||
logm = mp.logm
|
||||
sinm = mp.sinm
|
||||
cosm = mp.cosm
|
||||
|
||||
mpf = mp.mpf
|
||||
j = mp.j
|
||||
exp = mp.exp
|
||||
expj = mp.expj
|
||||
expjpi = mp.expjpi
|
||||
ln = mp.ln
|
||||
im = mp.im
|
||||
re = mp.re
|
||||
inf = mp.inf
|
||||
ninf = mp.ninf
|
||||
sign = mp.sign
|
||||
|
||||
eps = mp.eps
|
||||
pi = mp.pi
|
||||
ln2 = mp.ln2
|
||||
ln10 = mp.ln10
|
||||
phi = mp.phi
|
||||
e = mp.e
|
||||
euler = mp.euler
|
||||
catalan = mp.catalan
|
||||
khinchin = mp.khinchin
|
||||
glaisher = mp.glaisher
|
||||
apery = mp.apery
|
||||
degree = mp.degree
|
||||
twinprime = mp.twinprime
|
||||
mertens = mp.mertens
|
||||
|
||||
ldexp = mp.ldexp
|
||||
frexp = mp.frexp
|
||||
|
||||
fsum = mp.fsum
|
||||
fdot = mp.fdot
|
||||
|
||||
sqrt = mp.sqrt
|
||||
cbrt = mp.cbrt
|
||||
exp = mp.exp
|
||||
ln = mp.ln
|
||||
log = mp.log
|
||||
log10 = mp.log10
|
||||
power = mp.power
|
||||
cos = mp.cos
|
||||
sin = mp.sin
|
||||
tan = mp.tan
|
||||
cosh = mp.cosh
|
||||
sinh = mp.sinh
|
||||
tanh = mp.tanh
|
||||
acos = mp.acos
|
||||
asin = mp.asin
|
||||
atan = mp.atan
|
||||
asinh = mp.asinh
|
||||
acosh = mp.acosh
|
||||
atanh = mp.atanh
|
||||
sec = mp.sec
|
||||
csc = mp.csc
|
||||
cot = mp.cot
|
||||
sech = mp.sech
|
||||
csch = mp.csch
|
||||
coth = mp.coth
|
||||
asec = mp.asec
|
||||
acsc = mp.acsc
|
||||
acot = mp.acot
|
||||
asech = mp.asech
|
||||
acsch = mp.acsch
|
||||
acoth = mp.acoth
|
||||
cospi = mp.cospi
|
||||
sinpi = mp.sinpi
|
||||
sinc = mp.sinc
|
||||
sincpi = mp.sincpi
|
||||
fabs = mp.fabs
|
||||
re = mp.re
|
||||
im = mp.im
|
||||
conj = mp.conj
|
||||
floor = mp.floor
|
||||
ceil = mp.ceil
|
||||
root = mp.root
|
||||
nthroot = mp.nthroot
|
||||
hypot = mp.hypot
|
||||
modf = mp.modf
|
||||
ldexp = mp.ldexp
|
||||
frexp = mp.frexp
|
||||
sign = mp.sign
|
||||
arg = mp.arg
|
||||
phase = mp.phase
|
||||
polar = mp.polar
|
||||
rect = mp.rect
|
||||
degrees = mp.degrees
|
||||
radians = mp.radians
|
||||
atan2 = mp.atan2
|
||||
fib = mp.fib
|
||||
fibonacci = mp.fibonacci
|
||||
lambertw = mp.lambertw
|
||||
zeta = mp.zeta
|
||||
altzeta = mp.altzeta
|
||||
gamma = mp.gamma
|
||||
factorial = mp.factorial
|
||||
fac = mp.fac
|
||||
fac2 = mp.fac2
|
||||
beta = mp.beta
|
||||
betainc = mp.betainc
|
||||
psi = mp.psi
|
||||
#psi0 = mp.psi0
|
||||
#psi1 = mp.psi1
|
||||
#psi2 = mp.psi2
|
||||
#psi3 = mp.psi3
|
||||
polygamma = mp.polygamma
|
||||
digamma = mp.digamma
|
||||
#trigamma = mp.trigamma
|
||||
#tetragamma = mp.tetragamma
|
||||
#pentagamma = mp.pentagamma
|
||||
harmonic = mp.harmonic
|
||||
bernoulli = mp.bernoulli
|
||||
bernfrac = mp.bernfrac
|
||||
stieltjes = mp.stieltjes
|
||||
hurwitz = mp.hurwitz
|
||||
dirichlet = mp.dirichlet
|
||||
bernpoly = mp.bernpoly
|
||||
eulerpoly = mp.eulerpoly
|
||||
eulernum = mp.eulernum
|
||||
polylog = mp.polylog
|
||||
clsin = mp.clsin
|
||||
clcos = mp.clcos
|
||||
gammainc = mp.gammainc
|
||||
gammaprod = mp.gammaprod
|
||||
binomial = mp.binomial
|
||||
rf = mp.rf
|
||||
ff = mp.ff
|
||||
hyper = mp.hyper
|
||||
hyp0f1 = mp.hyp0f1
|
||||
hyp1f1 = mp.hyp1f1
|
||||
hyp1f2 = mp.hyp1f2
|
||||
hyp2f1 = mp.hyp2f1
|
||||
hyp2f2 = mp.hyp2f2
|
||||
hyp2f0 = mp.hyp2f0
|
||||
hyp2f3 = mp.hyp2f3
|
||||
hyp3f2 = mp.hyp3f2
|
||||
hyperu = mp.hyperu
|
||||
hypercomb = mp.hypercomb
|
||||
meijerg = mp.meijerg
|
||||
appellf1 = mp.appellf1
|
||||
erf = mp.erf
|
||||
erfc = mp.erfc
|
||||
erfi = mp.erfi
|
||||
erfinv = mp.erfinv
|
||||
npdf = mp.npdf
|
||||
ncdf = mp.ncdf
|
||||
expint = mp.expint
|
||||
e1 = mp.e1
|
||||
ei = mp.ei
|
||||
li = mp.li
|
||||
ci = mp.ci
|
||||
si = mp.si
|
||||
chi = mp.chi
|
||||
shi = mp.shi
|
||||
fresnels = mp.fresnels
|
||||
fresnelc = mp.fresnelc
|
||||
airyai = mp.airyai
|
||||
airybi = mp.airybi
|
||||
ellipe = mp.ellipe
|
||||
ellipk = mp.ellipk
|
||||
agm = mp.agm
|
||||
jacobi = mp.jacobi
|
||||
chebyt = mp.chebyt
|
||||
chebyu = mp.chebyu
|
||||
legendre = mp.legendre
|
||||
legenp = mp.legenp
|
||||
legenq = mp.legenq
|
||||
hermite = mp.hermite
|
||||
gegenbauer = mp.gegenbauer
|
||||
laguerre = mp.laguerre
|
||||
spherharm = mp.spherharm
|
||||
besselj = mp.besselj
|
||||
j0 = mp.j0
|
||||
j1 = mp.j1
|
||||
besseli = mp.besseli
|
||||
bessely = mp.bessely
|
||||
besselk = mp.besselk
|
||||
hankel1 = mp.hankel1
|
||||
hankel2 = mp.hankel2
|
||||
struveh = mp.struveh
|
||||
struvel = mp.struvel
|
||||
whitm = mp.whitm
|
||||
whitw = mp.whitw
|
||||
ber = mp.ber
|
||||
bei = mp.bei
|
||||
ker = mp.ker
|
||||
kei = mp.kei
|
||||
coulombc = mp.coulombc
|
||||
coulombf = mp.coulombf
|
||||
coulombg = mp.coulombg
|
||||
lambertw = mp.lambertw
|
||||
barnesg = mp.barnesg
|
||||
superfac = mp.superfac
|
||||
hyperfac = mp.hyperfac
|
||||
loggamma = mp.loggamma
|
||||
siegeltheta = mp.siegeltheta
|
||||
siegelz = mp.siegelz
|
||||
grampoint = mp.grampoint
|
||||
zetazero = mp.zetazero
|
||||
riemannr = mp.riemannr
|
||||
primepi = mp.primepi
|
||||
primepi2 = mp.primepi2
|
||||
primezeta = mp.primezeta
|
||||
bell = mp.bell
|
||||
polyexp = mp.polyexp
|
||||
expm1 = mp.expm1
|
||||
powm1 = mp.powm1
|
||||
unitroots = mp.unitroots
|
||||
cyclotomic = mp.cyclotomic
|
||||
|
||||
|
||||
# be careful when changing this name, don't use test*!
|
||||
def runtests():
|
||||
"""
|
||||
Run all mpmath tests and print output.
|
||||
"""
|
||||
import os.path
|
||||
from inspect import getsourcefile
|
||||
import tests.runtests as tests
|
||||
testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
|
||||
importdir = os.path.abspath(testdir + '/../..')
|
||||
tests.testit(importdir, testdir)
|
||||
|
||||
def doctests():
|
||||
try:
|
||||
import psyco; psyco.full()
|
||||
except ImportError:
|
||||
pass
|
||||
import sys
|
||||
from timeit import default_timer as clock
|
||||
filter = []
|
||||
for i, arg in enumerate(sys.argv):
|
||||
if '__init__.py' in arg:
|
||||
filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
|
||||
break
|
||||
import doctest
|
||||
globs = globals().copy()
|
||||
for obj in globs: #sorted(globs.keys()):
|
||||
if filter:
|
||||
if not sum([pat in obj for pat in filter]):
|
||||
continue
|
||||
print obj,
|
||||
t1 = clock()
|
||||
doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
|
||||
t2 = clock()
|
||||
print round(t2-t1, 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
doctests()
|
||||
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
import calculus
|
||||
# XXX: hack to set methods
|
||||
import approximation
|
||||
import differentiation
|
||||
import extrapolation
|
||||
import polynomials
|
||||
|
|
@ -1,246 +0,0 @@
|
|||
from calculus import defun
|
||||
|
||||
#----------------------------------------------------------------------------#
|
||||
# Approximation methods #
|
||||
#----------------------------------------------------------------------------#
|
||||
|
||||
# The Chebyshev approximation formula is given at:
|
||||
# http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
|
||||
|
||||
# The only major changes in the following code is that we return the
|
||||
# expanded polynomial coefficients instead of Chebyshev coefficients,
|
||||
# and that we automatically transform [a,b] -> [-1,1] and back
|
||||
# for convenience.
|
||||
|
||||
# Coefficient in Chebyshev approximation
|
||||
def chebcoeff(ctx,f,a,b,j,N):
|
||||
s = ctx.mpf(0)
|
||||
h = ctx.mpf(0.5)
|
||||
for k in range(1, N+1):
|
||||
t = ctx.cos(ctx.pi*(k-h)/N)
|
||||
s += f(t*(b-a)*h + (b+a)*h) * ctx.cos(ctx.pi*j*(k-h)/N)
|
||||
return 2*s/N
|
||||
|
||||
# Generate Chebyshev polynomials T_n(ax+b) in expanded form
|
||||
def chebT(ctx, a=1, b=0):
|
||||
Tb = [1]
|
||||
yield Tb
|
||||
Ta = [b, a]
|
||||
while 1:
|
||||
yield Ta
|
||||
# Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
|
||||
Tmp = [0] + [2*a*t for t in Ta]
|
||||
for i, c in enumerate(Ta): Tmp[i] += 2*b*c
|
||||
for i, c in enumerate(Tb): Tmp[i] -= c
|
||||
Ta, Tb = Tmp, Ta
|
||||
|
||||
@defun
|
||||
def chebyfit(ctx, f, interval, N, error=False):
|
||||
r"""
|
||||
Computes a polynomial of degree `N-1` that approximates the
|
||||
given function `f` on the interval `[a, b]`. With ``error=True``,
|
||||
:func:`chebyfit` also returns an accurate estimate of the
|
||||
maximum absolute error; that is, the maximum value of
|
||||
`|f(x) - P(x)|` for `x \in [a, b]`.
|
||||
|
||||
:func:`chebyfit` uses the Chebyshev approximation formula,
|
||||
which gives a nearly optimal solution: that is, the maximum
|
||||
error of the approximating polynomial is very close to
|
||||
the smallest possible for any polynomial of the same degree.
|
||||
|
||||
Chebyshev approximation is very useful if one needs repeated
|
||||
evaluation of an expensive function, such as function defined
|
||||
implicitly by an integral or a differential equation. (For
|
||||
example, it could be used to turn a slow mpmath function
|
||||
into a fast machine-precision version of the same.)
|
||||
|
||||
**Examples**
|
||||
|
||||
Here we use :func:`chebyfit` to generate a low-degree approximation
|
||||
of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
|
||||
>>> nprint(poly)
|
||||
[0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
|
||||
>>> nprint(err, 12)
|
||||
1.61351758081e-5
|
||||
|
||||
The polynomial can be evaluated using ``polyval``::
|
||||
|
||||
>>> nprint(polyval(poly, 1.6), 12)
|
||||
-0.0291858904138
|
||||
>>> nprint(cos(1.6), 12)
|
||||
-0.0291995223013
|
||||
|
||||
Sampling the true error at 1000 points shows that the error
|
||||
estimate generated by ``chebyfit`` is remarkably good::
|
||||
|
||||
>>> error = lambda x: abs(cos(x) - polyval(poly, x))
|
||||
>>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
|
||||
1.61349954245e-5
|
||||
|
||||
**Choice of degree**
|
||||
|
||||
The degree `N` can be set arbitrarily high, to obtain an
|
||||
arbitrarily good approximation. As a rule of thumb, an
|
||||
`N`-term Chebyshev approximation is good to `N/(b-a)` decimal
|
||||
places on a unit interval (although this depends on how
|
||||
well-behaved `f` is). The cost grows accordingly: ``chebyfit``
|
||||
evaluates the function `(N^2)/2` times to compute the
|
||||
coefficients and an additional `N` times to estimate the error.
|
||||
|
||||
**Possible issues**
|
||||
|
||||
One should be careful to use a sufficiently high working
|
||||
precision both when calling ``chebyfit`` and when evaluating
|
||||
the resulting polynomial, as the polynomial is sometimes
|
||||
ill-conditioned. It is for example difficult to reach
|
||||
15-digit accuracy when evaluating the polynomial using
|
||||
machine precision floats, no matter the theoretical
|
||||
accuracy of the polynomial. (The option to return the
|
||||
coefficients in Chebyshev form should be made available
|
||||
in the future.)
|
||||
|
||||
It is important to note the Chebyshev approximation works
|
||||
poorly if `f` is not smooth. A function containing singularities,
|
||||
rapid oscillation, etc can be approximated more effectively by
|
||||
multiplying it by a weight function that cancels out the
|
||||
nonsmooth features, or by dividing the interval into several
|
||||
segments.
|
||||
"""
|
||||
a, b = ctx._as_points(interval)
|
||||
orig = ctx.prec
|
||||
try:
|
||||
ctx.prec = orig + int(N**0.5) + 20
|
||||
c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
|
||||
d = [ctx.zero] * N
|
||||
d[0] = -c[0]/2
|
||||
h = ctx.mpf(0.5)
|
||||
T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
|
||||
for k in range(N):
|
||||
Tk = T.next()
|
||||
for i in range(len(Tk)):
|
||||
d[i] += c[k]*Tk[i]
|
||||
d = d[::-1]
|
||||
# Estimate maximum error
|
||||
err = ctx.zero
|
||||
for k in range(N):
|
||||
x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
|
||||
err = max(err, abs(f(x) - ctx.polyval(d, x)))
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
if error:
|
||||
return d, +err
|
||||
else:
|
||||
return d
|
||||
|
||||
@defun
|
||||
def fourier(ctx, f, interval, N):
|
||||
r"""
|
||||
Computes the Fourier series of degree `N` of the given function
|
||||
on the interval `[a, b]`. More precisely, :func:`fourier` returns
|
||||
two lists `(c, s)` of coefficients (the cosine series and sine
|
||||
series, respectively), such that
|
||||
|
||||
.. math ::
|
||||
|
||||
f(x) \sim \sum_{k=0}^N
|
||||
c_k \cos(k m) + s_k \sin(k m)
|
||||
|
||||
where `m = 2 \pi / (b-a)`.
|
||||
|
||||
Note that many texts define the first coefficient as `2 c_0` instead
|
||||
of `c_0`. The easiest way to evaluate the computed series correctly
|
||||
is to pass it to :func:`fourierval`.
|
||||
|
||||
**Examples**
|
||||
|
||||
The function `f(x) = x` has a simple Fourier series on the standard
|
||||
interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
|
||||
the function has odd symmetry), and the sine coefficients are
|
||||
rational numbers::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> c, s = fourier(lambda x: x, [-pi, pi], 5)
|
||||
>>> nprint(c)
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
>>> nprint(s)
|
||||
[0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]
|
||||
|
||||
This computes a Fourier series of a nonsymmetric function on
|
||||
a nonstandard interval::
|
||||
|
||||
>>> I = [-1, 1.5]
|
||||
>>> f = lambda x: x**2 - 4*x + 1
|
||||
>>> cs = fourier(f, I, 4)
|
||||
>>> nprint(cs[0])
|
||||
[0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
|
||||
>>> nprint(cs[1])
|
||||
[0.0, -2.6255, 0.580905, 0.219974, -0.540057]
|
||||
|
||||
It is instructive to plot a function along with its truncated
|
||||
Fourier series::
|
||||
|
||||
>>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP
|
||||
|
||||
Fourier series generally converge slowly (and may not converge
|
||||
pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
|
||||
series gives an `L^2` error corresponding to 2-digit accuracy::
|
||||
|
||||
>>> I = [-1, 1]
|
||||
>>> cs = fourier(cosh, I, 9)
|
||||
>>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
|
||||
>>> nprint(sqrt(quad(g, I)))
|
||||
0.00467963
|
||||
|
||||
:func:`fourier` uses numerical quadrature. For nonsmooth functions,
|
||||
the accuracy (and speed) can be improved by including all singular
|
||||
points in the interval specification::
|
||||
|
||||
>>> nprint(fourier(abs, [-1, 1], 0), 10)
|
||||
([0.5000441648], [0.0])
|
||||
>>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
|
||||
([0.5], [0.0])
|
||||
|
||||
"""
|
||||
interval = ctx._as_points(interval)
|
||||
a = interval[0]
|
||||
b = interval[-1]
|
||||
L = b-a
|
||||
cos_series = []
|
||||
sin_series = []
|
||||
cutoff = ctx.eps*10
|
||||
for n in xrange(N+1):
|
||||
m = 2*n*ctx.pi/L
|
||||
an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
|
||||
bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
|
||||
if n == 0:
|
||||
an /= 2
|
||||
if abs(an) < cutoff: an = ctx.zero
|
||||
if abs(bn) < cutoff: bn = ctx.zero
|
||||
cos_series.append(an)
|
||||
sin_series.append(bn)
|
||||
return cos_series, sin_series
|
||||
|
||||
@defun
|
||||
def fourierval(ctx, series, interval, x):
|
||||
"""
|
||||
Evaluates a Fourier series (in the format computed by
|
||||
by :func:`fourier` for the given interval) at the point `x`.
|
||||
|
||||
The series should be a pair `(c, s)` where `c` is the
|
||||
cosine series and `s` is the sine series. The two lists
|
||||
need not have the same length.
|
||||
"""
|
||||
cs, ss = series
|
||||
ab = ctx._as_points(interval)
|
||||
a = interval[0]
|
||||
b = interval[-1]
|
||||
m = 2*ctx.pi/(ab[-1]-ab[0])
|
||||
s = ctx.zero
|
||||
s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
|
||||
s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
|
||||
return s
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
class CalculusMethods(object):
|
||||
pass
|
||||
|
||||
def defun(f):
|
||||
setattr(CalculusMethods, f.__name__, f)
|
||||
|
|
@ -1,438 +0,0 @@
|
|||
from calculus import defun
|
||||
|
||||
#----------------------------------------------------------------------------#
|
||||
# Differentiation #
|
||||
#----------------------------------------------------------------------------#
|
||||
|
||||
@defun
|
||||
def difference_delta(ctx, s, n):
|
||||
r"""
|
||||
Given a sequence `(s_k)` containing at least `n+1` items, returns the
|
||||
`n`-th forward difference,
|
||||
|
||||
.. math ::
|
||||
|
||||
\Delta^n = \sum_{k=0}^{\infty} (-1)^{k+n} {n \choose k} s_k.
|
||||
"""
|
||||
n = int(n)
|
||||
d = ctx.zero
|
||||
b = (-1) ** (n & 1)
|
||||
for k in xrange(n+1):
|
||||
d += b * s[k]
|
||||
b = (b * (k-n)) // (k+1)
|
||||
return d
|
||||
|
||||
@defun
|
||||
def diff(ctx, f, x, n=1, method='step', scale=1, direction=0):
|
||||
r"""
|
||||
Numerically computes the derivative of `f`, `f'(x)`. Optionally,
|
||||
computes the `n`-th derivative, `f^{(n)}(x)`, for any order `n`.
|
||||
|
||||
**Basic examples**
|
||||
|
||||
Derivatives of a simple function::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> diff(lambda x: x**2 + x, 1.0)
|
||||
3.0
|
||||
>>> diff(lambda x: x**2 + x, 1.0, 2)
|
||||
2.0
|
||||
>>> diff(lambda x: x**2 + x, 1.0, 3)
|
||||
0.0
|
||||
|
||||
The exponential function is invariant under differentiation::
|
||||
|
||||
>>> nprint([diff(exp, 3, n) for n in range(5)])
|
||||
[20.0855, 20.0855, 20.0855, 20.0855, 20.0855]
|
||||
|
||||
**Method**
|
||||
|
||||
One of two differentiation algorithms can be chosen with the
|
||||
``method`` keyword argument. The two options are ``'step'``,
|
||||
and ``'quad'``. The default method is ``'step'``.
|
||||
|
||||
``'step'``:
|
||||
|
||||
The derivative is computed using a finite difference
|
||||
approximation, with a small step h. This requires n+1 function
|
||||
evaluations and must be performed at (n+1) times the target
|
||||
precision. Accordingly, f must support fast evaluation at high
|
||||
precision.
|
||||
|
||||
``'quad'``:
|
||||
|
||||
The derivative is computed using complex
|
||||
numerical integration. This requires a larger number of function
|
||||
evaluations, but the advantage is that not much extra precision
|
||||
is required. For high order derivatives, this method may thus
|
||||
be faster if f is very expensive to evaluate at high precision.
|
||||
|
||||
With ``'quad'`` the result is likely to have a small imaginary
|
||||
component even if the derivative is actually real::
|
||||
|
||||
>>> diff(sqrt, 1, method='quad') # doctest:+ELLIPSIS
|
||||
(0.5 - 9.44...e-27j)
|
||||
|
||||
**Scale**
|
||||
|
||||
The scale option specifies the scale of variation of f. The step
|
||||
size in the finite difference is taken to be approximately
|
||||
eps*scale. Thus, for example if `f(x) = \cos(1000 x)`, the scale
|
||||
should be set to 1/1000 and if `f(x) = \cos(x/1000)`, the scale
|
||||
should be 1000. By default, scale = 1.
|
||||
|
||||
(In practice, the default scale will work even for `\cos(1000 x)` or
|
||||
`\cos(x/1000)`. Changing this parameter is a good idea if the scale
|
||||
is something *preposterous*.)
|
||||
|
||||
If numerical integration is used, the radius of integration is
|
||||
taken to be equal to scale/2. Note that f must not have any
|
||||
singularities within the circle of radius scale/2 centered around
|
||||
x. If possible, a larger scale value is preferable because it
|
||||
typically makes the integration faster and more accurate.
|
||||
|
||||
**Direction**
|
||||
|
||||
By default, :func:`diff` uses a central difference approximation.
|
||||
This corresponds to direction=0. Alternatively, it can compute a
|
||||
left difference (direction=-1) or right difference (direction=1).
|
||||
This is useful for computing left- or right-sided derivatives
|
||||
of nonsmooth functions:
|
||||
|
||||
>>> diff(abs, 0, direction=0)
|
||||
0.0
|
||||
>>> diff(abs, 0, direction=1)
|
||||
1.0
|
||||
>>> diff(abs, 0, direction=-1)
|
||||
-1.0
|
||||
|
||||
More generally, if the direction is nonzero, a right difference
|
||||
is computed where the step size is multiplied by sign(direction).
|
||||
For example, with direction=+j, the derivative from the positive
|
||||
imaginary direction will be computed.
|
||||
|
||||
This option only makes sense with method='step'. If integration
|
||||
is used, it is assumed that f is analytic, implying that the
|
||||
derivative is the same in all directions.
|
||||
|
||||
"""
|
||||
if n == 0:
|
||||
return f(ctx.convert(x))
|
||||
orig = ctx.prec
|
||||
try:
|
||||
if method == 'step':
|
||||
ctx.prec = (orig+20) * (n+1)
|
||||
h = ctx.ldexp(scale, -orig-10)
|
||||
# Applying the finite difference formula recursively n times,
|
||||
# we get a step sum weighted by a row of binomial coefficients
|
||||
# Directed: steps x, x+h, ... x+n*h
|
||||
if direction:
|
||||
h *= ctx.sign(direction)
|
||||
steps = xrange(n+1)
|
||||
norm = h**n
|
||||
# Central: steps x-n*h, x-(n-2)*h ..., x, ..., x+(n-2)*h, x+n*h
|
||||
else:
|
||||
steps = xrange(-n, n+1, 2)
|
||||
norm = (2*h)**n
|
||||
v = ctx.difference_delta([f(x+k*h) for k in steps], n)
|
||||
v = v / norm
|
||||
elif method == 'quad':
|
||||
ctx.prec += 10
|
||||
radius = ctx.mpf(scale)/2
|
||||
def g(t):
|
||||
rei = radius*ctx.expj(t)
|
||||
z = x + rei
|
||||
return f(z) / rei**n
|
||||
d = ctx.quadts(g, [0, 2*ctx.pi])
|
||||
v = d * ctx.factorial(n) / (2*ctx.pi)
|
||||
else:
|
||||
raise ValueError("unknown method: %r" % method)
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
return +v
|
||||
|
||||
@defun
|
||||
def diffs(ctx, f, x, n=None, method='step', scale=1, direction=0):
|
||||
r"""
|
||||
Returns a generator that yields the sequence of derivatives
|
||||
|
||||
.. math ::
|
||||
|
||||
f(x), f'(x), f''(x), \ldots, f^{(k)}(x), \ldots
|
||||
|
||||
With ``method='step'``, :func:`diffs` uses only `O(k)`
|
||||
function evaluations to generate the first `k` derivatives,
|
||||
rather than the roughly `O(k^2)` evaluations
|
||||
required if one calls :func:`diff` `k` separate times.
|
||||
|
||||
With `n < \infty`, the generator stops as soon as the
|
||||
`n`-th derivative has been generated. If the exact number of
|
||||
needed derivatives is known in advance, this is further
|
||||
slightly more efficient.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15
|
||||
>>> nprint(list(diffs(cos, 1, 5)))
|
||||
[0.540302, -0.841471, -0.540302, 0.841471, 0.540302, -0.841471]
|
||||
>>> for i, d in zip(range(6), diffs(cos, 1)): print i, d
|
||||
...
|
||||
0 0.54030230586814
|
||||
1 -0.841470984807897
|
||||
2 -0.54030230586814
|
||||
3 0.841470984807897
|
||||
4 0.54030230586814
|
||||
5 -0.841470984807897
|
||||
|
||||
"""
|
||||
if n is None:
|
||||
n = ctx.inf
|
||||
else:
|
||||
n = int(n)
|
||||
|
||||
if method != 'step':
|
||||
k = 0
|
||||
while k < n:
|
||||
yield ctx.diff(f, x, k)
|
||||
k += 1
|
||||
return
|
||||
|
||||
targetprec = ctx.prec
|
||||
|
||||
def getvalues(m):
|
||||
callprec = ctx.prec
|
||||
try:
|
||||
ctx.prec = workprec = (targetprec+20) * (m+1)
|
||||
h = ctx.ldexp(scale, -targetprec-10)
|
||||
if direction:
|
||||
h *= ctx.sign(direction)
|
||||
y = [f(x+h*k) for k in xrange(m+1)]
|
||||
hnorm = h
|
||||
else:
|
||||
y = [f(x+h*k) for k in xrange(-m, m+1, 2)]
|
||||
hnorm = 2*h
|
||||
return y, hnorm, workprec
|
||||
finally:
|
||||
ctx.prec = callprec
|
||||
|
||||
yield f(ctx.convert(x))
|
||||
if n < 1:
|
||||
return
|
||||
|
||||
if n == ctx.inf:
|
||||
A, B = 1, 2
|
||||
else:
|
||||
A, B = 1, n+1
|
||||
|
||||
while 1:
|
||||
y, hnorm, workprec = getvalues(B)
|
||||
for k in xrange(A, B):
|
||||
try:
|
||||
callprec = ctx.prec
|
||||
ctx.prec = workprec
|
||||
d = ctx.difference_delta(y, k) / hnorm**k
|
||||
finally:
|
||||
ctx.prec = callprec
|
||||
yield +d
|
||||
if k >= n:
|
||||
return
|
||||
A, B = B, int(A*1.4+1)
|
||||
B = min(B, n)
|
||||
|
||||
@defun
|
||||
def differint(ctx, f, x, n=1, x0=0):
|
||||
r"""
|
||||
Calculates the Riemann-Liouville differintegral, or fractional
|
||||
derivative, defined by
|
||||
|
||||
.. math ::
|
||||
|
||||
\,_{x_0}{\mathbb{D}}^n_xf(x) \frac{1}{\Gamma(m-n)} \frac{d^m}{dx^m}
|
||||
\int_{x_0}^{x}(x-t)^{m-n-1}f(t)dt
|
||||
|
||||
where `f` is a given (presumably well-behaved) function,
|
||||
`x` is the evaluation point, `n` is the order, and `x_0` is
|
||||
the reference point of integration (`m` is an arbitrary
|
||||
parameter selected automatically).
|
||||
|
||||
With `n = 1`, this is just the standard derivative `f'(x)`; with `n = 2`,
|
||||
the second derivative `f''(x)`, etc. With `n = -1`, it gives
|
||||
`\int_{x_0}^x f(t) dt`, with `n = -2`
|
||||
it gives `\int_{x_0}^x \left( \int_{x_0}^t f(u) du \right) dt`, etc.
|
||||
|
||||
As `n` is permitted to be any number, this operator generalizes
|
||||
iterated differentiation and iterated integration to a single
|
||||
operator with a continuous order parameter.
|
||||
|
||||
**Examples**
|
||||
|
||||
There is an exact formula for the fractional derivative of a
|
||||
monomial `x^p`, which may be used as a reference. For example,
|
||||
the following gives a half-derivative (order 0.5)::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> x = mpf(3); p = 2; n = 0.5
|
||||
>>> differint(lambda t: t**p, x, n)
|
||||
7.81764019044672
|
||||
>>> gamma(p+1)/gamma(p-n+1) * x**(p-n)
|
||||
7.81764019044672
|
||||
|
||||
Another useful test function is the exponential function, whose
|
||||
integration / differentiation formula easy generalizes
|
||||
to arbitrary order. Here we first compute a third derivative,
|
||||
and then a triply nested integral. (The reference point `x_0`
|
||||
is set to `-\infty` to avoid nonzero endpoint terms.)::
|
||||
|
||||
>>> differint(lambda x: exp(pi*x), -1.5, 3)
|
||||
0.278538406900792
|
||||
>>> exp(pi*-1.5) * pi**3
|
||||
0.278538406900792
|
||||
>>> differint(lambda x: exp(pi*x), 3.5, -3, -inf)
|
||||
1922.50563031149
|
||||
>>> exp(pi*3.5) / pi**3
|
||||
1922.50563031149
|
||||
|
||||
However, for noninteger `n`, the differentiation formula for the
|
||||
exponential function must be modified to give the same result as the
|
||||
Riemann-Liouville differintegral::
|
||||
|
||||
>>> x = mpf(3.5)
|
||||
>>> c = pi
|
||||
>>> n = 1+2*j
|
||||
>>> differint(lambda x: exp(c*x), x, n)
|
||||
(-123295.005390743 + 140955.117867654j)
|
||||
>>> x**(-n) * exp(c)**x * (x*c)**n * gammainc(-n, 0, x*c) / gamma(-n)
|
||||
(-123295.005390743 + 140955.117867654j)
|
||||
|
||||
|
||||
"""
|
||||
m = max(int(ctx.ceil(ctx.re(n)))+1, 1)
|
||||
r = m-n-1
|
||||
g = lambda x: ctx.quad(lambda t: (x-t)**r * f(t), [x0, x])
|
||||
return ctx.diff(g, x, m) / ctx.gamma(m-n)
|
||||
|
||||
@defun
|
||||
def diffun(ctx, f, n=1, **options):
|
||||
"""
|
||||
Given a function f, returns a function g(x) that evaluates the nth
|
||||
derivative f^(n)(x)::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> cos2 = diffun(sin)
|
||||
>>> sin2 = diffun(sin, 4)
|
||||
>>> cos(1.3), cos2(1.3)
|
||||
(0.267498828624587, 0.267498828624587)
|
||||
>>> sin(1.3), sin2(1.3)
|
||||
(0.963558185417193, 0.963558185417193)
|
||||
|
||||
The function f must support arbitrary precision evaluation.
|
||||
See :func:`diff` for additional details and supported
|
||||
keyword options.
|
||||
"""
|
||||
if n == 0:
|
||||
return f
|
||||
def g(x):
|
||||
return ctx.diff(f, x, n, **options)
|
||||
return g
|
||||
|
||||
@defun
|
||||
def taylor(ctx, f, x, n, **options):
|
||||
r"""
|
||||
Produces a degree-`n` Taylor polynomial around the point `x` of the
|
||||
given function `f`. The coefficients are returned as a list.
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> nprint(chop(taylor(sin, 0, 5)))
|
||||
[0.0, 1.0, 0.0, -0.166667, 0.0, 0.00833333]
|
||||
|
||||
The coefficients are computed using high-order numerical
|
||||
differentiation. The function must be possible to evaluate
|
||||
to arbitrary precision. See :func:`diff` for additional details
|
||||
and supported keyword options.
|
||||
|
||||
Note that to evaluate the Taylor polynomial as an approximation
|
||||
of `f`, e.g. with :func:`polyval`, the coefficients must be reversed,
|
||||
and the point of the Taylor expansion must be subtracted from
|
||||
the argument:
|
||||
|
||||
>>> p = taylor(exp, 2.0, 10)
|
||||
>>> polyval(p[::-1], 2.5 - 2.0)
|
||||
12.1824939606092
|
||||
>>> exp(2.5)
|
||||
12.1824939607035
|
||||
|
||||
"""
|
||||
return [d/ctx.factorial(i) for i, d in enumerate(ctx.diffs(f, x, n, **options))]
|
||||
|
||||
@defun
|
||||
def pade(ctx, a, L, M):
|
||||
r"""
|
||||
Computes a Pade approximation of degree `(L, M)` to a function.
|
||||
Given at least `L+M+1` Taylor coefficients `a` approximating
|
||||
a function `A(x)`, :func:`pade` returns coefficients of
|
||||
polynomials `P, Q` satisfying
|
||||
|
||||
.. math ::
|
||||
|
||||
P = \sum_{k=0}^L p_k x^k
|
||||
|
||||
Q = \sum_{k=0}^M q_k x^k
|
||||
|
||||
Q_0 = 1
|
||||
|
||||
A(x) Q(x) = P(x) + O(x^{L+M+1})
|
||||
|
||||
`P(x)/Q(x)` can provide a good approximation to an analytic function
|
||||
beyond the radius of convergence of its Taylor series (example
|
||||
from G.A. Baker 'Essentials of Pade Approximants' Academic Press,
|
||||
Ch.1A)::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> one = mpf(1)
|
||||
>>> def f(x):
|
||||
... return sqrt((one + 2*x)/(one + x))
|
||||
...
|
||||
>>> a = taylor(f, 0, 6)
|
||||
>>> p, q = pade(a, 3, 3)
|
||||
>>> x = 10
|
||||
>>> polyval(p[::-1], x)/polyval(q[::-1], x)
|
||||
1.38169105566806
|
||||
>>> f(x)
|
||||
1.38169855941551
|
||||
|
||||
"""
|
||||
# To determine L+1 coefficients of P and M coefficients of Q
|
||||
# L+M+1 coefficients of A must be provided
|
||||
assert(len(a) >= L+M+1)
|
||||
|
||||
if M == 0:
|
||||
if L == 0:
|
||||
return [ctx.one], [ctx.one]
|
||||
else:
|
||||
return a[:L+1], [ctx.one]
|
||||
|
||||
# Solve first
|
||||
# a[L]*q[1] + ... + a[L-M+1]*q[M] = -a[L+1]
|
||||
# ...
|
||||
# a[L+M-1]*q[1] + ... + a[L]*q[M] = -a[L+M]
|
||||
A = ctx.matrix(M)
|
||||
for j in range(M):
|
||||
for i in range(min(M, L+j+1)):
|
||||
A[j, i] = a[L+j-i]
|
||||
v = -ctx.matrix(a[(L+1):(L+M+1)])
|
||||
x = ctx.lu_solve(A, v)
|
||||
q = [ctx.one] + list(x)
|
||||
# compute p
|
||||
p = [0]*(L+1)
|
||||
for i in range(L+1):
|
||||
s = a[i]
|
||||
for j in range(1, min(M,i) + 1):
|
||||
s += q[j]*a[i-j]
|
||||
p[i] = s
|
||||
return p, q
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,287 +0,0 @@
|
|||
from bisect import bisect
|
||||
|
||||
class ODEMethods(object):
|
||||
pass
|
||||
|
||||
def ode_taylor(ctx, derivs, x0, y0, tol_prec, n):
|
||||
h = tol = ctx.ldexp(1, -tol_prec)
|
||||
dim = len(y0)
|
||||
xs = [x0]
|
||||
ys = [y0]
|
||||
x = x0
|
||||
y = y0
|
||||
orig = ctx.prec
|
||||
try:
|
||||
ctx.prec = orig*(1+n)
|
||||
# Use n steps with Euler's method to get
|
||||
# evaluation points for derivatives
|
||||
for i in range(n):
|
||||
fxy = derivs(x, y)
|
||||
y = [y[i]+h*fxy[i] for i in xrange(len(y))]
|
||||
x += h
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
# Compute derivatives
|
||||
ser = [[] for d in range(dim)]
|
||||
for j in range(n+1):
|
||||
s = [0]*dim
|
||||
b = (-1) ** (j & 1)
|
||||
k = 1
|
||||
for i in range(j+1):
|
||||
for d in range(dim):
|
||||
s[d] += b * ys[i][d]
|
||||
b = (b * (j-k+1)) // (-k)
|
||||
k += 1
|
||||
scale = h**(-j) / ctx.fac(j)
|
||||
for d in range(dim):
|
||||
s[d] = s[d] * scale
|
||||
ser[d].append(s[d])
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
# Estimate radius for which we can get full accuracy.
|
||||
# XXX: do this right for zeros
|
||||
radius = ctx.one
|
||||
for ts in ser:
|
||||
if ts[-1]:
|
||||
radius = min(radius, ctx.nthroot(tol/abs(ts[-1]), n))
|
||||
radius /= 2 # XXX
|
||||
return ser, x0+radius
|
||||
|
||||
def odefun(ctx, F, x0, y0, tol=None, degree=None, method='taylor', verbose=False):
|
||||
r"""
|
||||
Returns a function `y(x) = [y_0(x), y_1(x), \ldots, y_n(x)]`
|
||||
that is a numerical solution of the `n+1`-dimensional first-order
|
||||
ordinary differential equation (ODE) system
|
||||
|
||||
.. math ::
|
||||
|
||||
y_0'(x) = F_0(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
||||
|
||||
y_1'(x) = F_1(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
||||
|
||||
\vdots
|
||||
|
||||
y_n'(x) = F_n(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
||||
|
||||
The derivatives are specified by the vector-valued function
|
||||
*F* that evaluates
|
||||
`[y_0', \ldots, y_n'] = F(x, [y_0, \ldots, y_n])`.
|
||||
The initial point `x_0` is specified by the scalar argument *x0*,
|
||||
and the initial value `y(x_0) = [y_0(x_0), \ldots, y_n(x_0)]` is
|
||||
specified by the vector argument *y0*.
|
||||
|
||||
For convenience, if the system is one-dimensional, you may optionally
|
||||
provide just a scalar value for *y0*. In this case, *F* should accept
|
||||
a scalar *y* argument and return a scalar. The solution function
|
||||
*y* will return scalar values instead of length-1 vectors.
|
||||
|
||||
Evaluation of the solution function `y(x)` is permitted
|
||||
for any `x \ge x_0`.
|
||||
|
||||
A high-order ODE can be solved by transforming it into first-order
|
||||
vector form. This transformation is described in standard texts
|
||||
on ODEs. Examples will also be given below.
|
||||
|
||||
**Options, speed and accuracy**
|
||||
|
||||
By default, :func:`odefun` uses a high-order Taylor series
|
||||
method. For reasonably well-behaved problems, the solution will
|
||||
be fully accurate to within the working precision. Note that
|
||||
*F* must be possible to evaluate to very high precision
|
||||
for the generation of Taylor series to work.
|
||||
|
||||
To get a faster but less accurate solution, you can set a large
|
||||
value for *tol* (which defaults roughly to *eps*). If you just
|
||||
want to plot the solution or perform a basic simulation,
|
||||
*tol = 0.01* is likely sufficient.
|
||||
|
||||
The *degree* argument controls the degree of the solver (with
|
||||
*method='taylor'*, this is the degree of the Taylor series
|
||||
expansion). A higher degree means that a longer step can be taken
|
||||
before a new local solution must be generated from *F*,
|
||||
meaning that fewer steps are required to get from `x_0` to a given
|
||||
`x_1`. On the other hand, a higher degree also means that each
|
||||
local solution becomes more expensive (i.e., more evaluations of
|
||||
*F* are required per step, and at higher precision).
|
||||
|
||||
The optimal setting therefore involves a tradeoff. Generally,
|
||||
decreasing the *degree* for Taylor series is likely to give faster
|
||||
solution at low precision, while increasing is likely to be better
|
||||
at higher precision.
|
||||
|
||||
The function
|
||||
object returned by :func:`odefun` caches the solutions at all step
|
||||
points and uses polynomial interpolation between step points.
|
||||
Therefore, once `y(x_1)` has been evaluated for some `x_1`,
|
||||
`y(x)` can be evaluated very quickly for any `x_0 \le x \le x_1`.
|
||||
and continuing the evaluation up to `x_2 > x_1` is also fast.
|
||||
|
||||
**Examples of first-order ODEs**
|
||||
|
||||
We will solve the standard test problem `y'(x) = y(x), y(0) = 1`
|
||||
which has explicit solution `y(x) = \exp(x)`::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> f = odefun(lambda x, y: y, 0, 1)
|
||||
>>> for x in [0, 1, 2.5]:
|
||||
... print f(x), exp(x)
|
||||
...
|
||||
1.0 1.0
|
||||
2.71828182845905 2.71828182845905
|
||||
12.1824939607035 12.1824939607035
|
||||
|
||||
The solution with high precision::
|
||||
|
||||
>>> mp.dps = 50
|
||||
>>> f = odefun(lambda x, y: y, 0, 1)
|
||||
>>> f(1)
|
||||
2.7182818284590452353602874713526624977572470937
|
||||
>>> exp(1)
|
||||
2.7182818284590452353602874713526624977572470937
|
||||
|
||||
Using the more general vectorized form, the test problem
|
||||
can be input as (note that *f* returns a 1-element vector)::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> f = odefun(lambda x, y: [y[0]], 0, [1])
|
||||
>>> f(1)
|
||||
[2.71828182845905]
|
||||
|
||||
:func:`odefun` can solve nonlinear ODEs, which are generally
|
||||
impossible (and at best difficult) to solve analytically. As
|
||||
an example of a nonlinear ODE, we will solve `y'(x) = x \sin(y(x))`
|
||||
for `y(0) = \pi/2`. An exact solution happens to be known
|
||||
for this problem, and is given by
|
||||
`y(x) = 2 \tan^{-1}\left(\exp\left(x^2/2\right)\right)`::
|
||||
|
||||
>>> f = odefun(lambda x, y: x*sin(y), 0, pi/2)
|
||||
>>> for x in [2, 5, 10]:
|
||||
... print f(x), 2*atan(exp(mpf(x)**2/2))
|
||||
...
|
||||
2.87255666284091 2.87255666284091
|
||||
3.14158520028345 3.14158520028345
|
||||
3.14159265358979 3.14159265358979
|
||||
|
||||
If `F` is independent of `y`, an ODE can be solved using direct
|
||||
integration. We can therefore obtain a reference solution with
|
||||
:func:`quad`::
|
||||
|
||||
>>> f = lambda x: (1+x**2)/(1+x**3)
|
||||
>>> g = odefun(lambda x, y: f(x), pi, 0)
|
||||
>>> g(2*pi)
|
||||
0.72128263801696
|
||||
>>> quad(f, [pi, 2*pi])
|
||||
0.72128263801696
|
||||
|
||||
**Examples of second-order ODEs**
|
||||
|
||||
We will solve the harmonic oscillator equation `y''(x) + y(x) = 0`.
|
||||
To do this, we introduce the helper functions `y_0 = y, y_1 = y_0'`
|
||||
whereby the original equation can be written as `y_1' + y_0' = 0`. Put
|
||||
together, we get the first-order, two-dimensional vector ODE
|
||||
|
||||
.. math ::
|
||||
|
||||
\begin{cases}
|
||||
y_0' = y_1 \\
|
||||
y_1' = -y_0
|
||||
\end{cases}
|
||||
|
||||
To get a well-defined IVP, we need two initial values. With
|
||||
`y(0) = y_0(0) = 1` and `-y'(0) = y_1(0) = 0`, the problem will of
|
||||
course be solved by `y(x) = y_0(x) = \cos(x)` and
|
||||
`-y'(x) = y_1(x) = \sin(x)`. We check this::
|
||||
|
||||
>>> f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
|
||||
>>> for x in [0, 1, 2.5, 10]:
|
||||
... nprint(f(x), 15)
|
||||
... nprint([cos(x), sin(x)], 15)
|
||||
... print "---"
|
||||
...
|
||||
[1.0, 0.0]
|
||||
[1.0, 0.0]
|
||||
---
|
||||
[0.54030230586814, 0.841470984807897]
|
||||
[0.54030230586814, 0.841470984807897]
|
||||
---
|
||||
[-0.801143615546934, 0.598472144103957]
|
||||
[-0.801143615546934, 0.598472144103957]
|
||||
---
|
||||
[-0.839071529076452, -0.54402111088937]
|
||||
[-0.839071529076452, -0.54402111088937]
|
||||
---
|
||||
|
||||
Note that we get both the sine and the cosine solutions
|
||||
simultaneously.
|
||||
|
||||
**TODO**
|
||||
|
||||
* Better automatic choice of degree and step size
|
||||
* Make determination of Taylor series convergence radius
|
||||
more robust
|
||||
* Allow solution for `x < x_0`
|
||||
* Allow solution for complex `x`
|
||||
* Test for difficult (ill-conditioned) problems
|
||||
* Implement Runge-Kutta and other algorithms
|
||||
|
||||
"""
|
||||
if tol:
|
||||
tol_prec = int(-ctx.log(tol, 2))+10
|
||||
else:
|
||||
tol_prec = ctx.prec+10
|
||||
degree = degree or (3 + int(3*ctx.dps/2.))
|
||||
workprec = ctx.prec + 40
|
||||
try:
|
||||
len(y0)
|
||||
return_vector = True
|
||||
except TypeError:
|
||||
F_ = F
|
||||
F = lambda x, y: [F_(x, y[0])]
|
||||
y0 = [y0]
|
||||
return_vector = False
|
||||
ser, xb = ode_taylor(ctx, F, x0, y0, tol_prec, degree)
|
||||
series_boundaries = [x0, xb]
|
||||
series_data = [(ser, x0, xb)]
|
||||
# We will be working with vectors of Taylor series
|
||||
def mpolyval(ser, a):
|
||||
return [ctx.polyval(s[::-1], a) for s in ser]
|
||||
# Find nearest expansion point; compute if necessary
|
||||
def get_series(x):
|
||||
if x < x0:
|
||||
raise ValueError
|
||||
n = bisect(series_boundaries, x)
|
||||
if n < len(series_boundaries):
|
||||
return series_data[n-1]
|
||||
while 1:
|
||||
ser, xa, xb = series_data[-1]
|
||||
if verbose:
|
||||
print "Computing Taylor series for [%f, %f]" % (xa, xb)
|
||||
y = mpolyval(ser, xb-xa)
|
||||
xa = xb
|
||||
ser, xb = ode_taylor(ctx, F, xb, y, tol_prec, degree)
|
||||
series_boundaries.append(xb)
|
||||
series_data.append((ser, xa, xb))
|
||||
if x <= xb:
|
||||
return series_data[-1]
|
||||
# Evaluation function
|
||||
def interpolant(x):
|
||||
x = ctx.convert(x)
|
||||
orig = ctx.prec
|
||||
try:
|
||||
ctx.prec = workprec
|
||||
ser, xa, xb = get_series(x)
|
||||
y = mpolyval(ser, x-xa)
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
if return_vector:
|
||||
return [+yk for yk in y]
|
||||
else:
|
||||
return +y[0]
|
||||
return interpolant
|
||||
|
||||
ODEMethods.odefun = odefun
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,189 +0,0 @@
|
|||
from calculus import defun
|
||||
|
||||
#----------------------------------------------------------------------------#
|
||||
# Polynomials #
|
||||
#----------------------------------------------------------------------------#
|
||||
|
||||
# XXX: extra precision
|
||||
@defun
|
||||
def polyval(ctx, coeffs, x, derivative=False):
|
||||
r"""
|
||||
Given coefficients `[c_n, \ldots, c_2, c_1, c_0]` and a number `x`,
|
||||
:func:`polyval` evaluates the polynomial
|
||||
|
||||
.. math ::
|
||||
|
||||
P(x) = c_n x^n + \ldots + c_2 x^2 + c_1 x + c_0.
|
||||
|
||||
If *derivative=True* is set, :func:`polyval` simultaneously
|
||||
evaluates `P(x)` with the derivative, `P'(x)`, and returns the
|
||||
tuple `(P(x), P'(x))`.
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.pretty = True
|
||||
>>> polyval([3, 0, 2], 0.5)
|
||||
2.75
|
||||
>>> polyval([3, 0, 2], 0.5, derivative=True)
|
||||
(2.75, 3.0)
|
||||
|
||||
The coefficients and the evaluation point may be any combination
|
||||
of real or complex numbers.
|
||||
"""
|
||||
if not coeffs:
|
||||
return ctx.zero
|
||||
p = ctx.convert(coeffs[0])
|
||||
q = ctx.zero
|
||||
for c in coeffs[1:]:
|
||||
if derivative:
|
||||
q = p + x*q
|
||||
p = c + x*p
|
||||
if derivative:
|
||||
return p, q
|
||||
else:
|
||||
return p
|
||||
|
||||
@defun
|
||||
def polyroots(ctx, coeffs, maxsteps=50, cleanup=True, extraprec=10, error=False):
|
||||
"""
|
||||
Computes all roots (real or complex) of a given polynomial. The roots are
|
||||
returned as a sorted list, where real roots appear first followed by
|
||||
complex conjugate roots as adjacent elements. The polynomial should be
|
||||
given as a list of coefficients, in the format used by :func:`polyval`.
|
||||
The leading coefficient must be nonzero.
|
||||
|
||||
With *error=True*, :func:`polyroots` returns a tuple *(roots, err)* where
|
||||
*err* is an estimate of the maximum error among the computed roots.
|
||||
|
||||
**Examples**
|
||||
|
||||
Finding the three real roots of `x^3 - x^2 - 14x + 24`::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> nprint(polyroots([1,-1,-14,24]), 4)
|
||||
[-4.0, 2.0, 3.0]
|
||||
|
||||
Finding the two complex conjugate roots of `4x^2 + 3x + 2`, with an
|
||||
error estimate::
|
||||
|
||||
>>> roots, err = polyroots([4,3,2], error=True)
|
||||
>>> for r in roots:
|
||||
... print r
|
||||
...
|
||||
(-0.375 + 0.59947894041409j)
|
||||
(-0.375 - 0.59947894041409j)
|
||||
>>>
|
||||
>>> err
|
||||
2.22044604925031e-16
|
||||
>>>
|
||||
>>> polyval([4,3,2], roots[0])
|
||||
(2.22044604925031e-16 + 0.0j)
|
||||
>>> polyval([4,3,2], roots[1])
|
||||
(2.22044604925031e-16 + 0.0j)
|
||||
|
||||
The following example computes all the 5th roots of unity; that is,
|
||||
the roots of `x^5 - 1`::
|
||||
|
||||
>>> mp.dps = 20
|
||||
>>> for r in polyroots([1, 0, 0, 0, 0, -1]):
|
||||
... print r
|
||||
...
|
||||
1.0
|
||||
(-0.8090169943749474241 + 0.58778525229247312917j)
|
||||
(-0.8090169943749474241 - 0.58778525229247312917j)
|
||||
(0.3090169943749474241 + 0.95105651629515357212j)
|
||||
(0.3090169943749474241 - 0.95105651629515357212j)
|
||||
|
||||
**Precision and conditioning**
|
||||
|
||||
Provided there are no repeated roots, :func:`polyroots` can typically
|
||||
compute all roots of an arbitrary polynomial to high precision::
|
||||
|
||||
>>> mp.dps = 60
|
||||
>>> for r in polyroots([1, 0, -10, 0, 1]):
|
||||
... print r
|
||||
...
|
||||
-3.14626436994197234232913506571557044551247712918732870123249
|
||||
-0.317837245195782244725757617296174288373133378433432554879127
|
||||
0.317837245195782244725757617296174288373133378433432554879127
|
||||
3.14626436994197234232913506571557044551247712918732870123249
|
||||
>>>
|
||||
>>> sqrt(3) + sqrt(2)
|
||||
3.14626436994197234232913506571557044551247712918732870123249
|
||||
>>> sqrt(3) - sqrt(2)
|
||||
0.317837245195782244725757617296174288373133378433432554879127
|
||||
|
||||
**Algorithm**
|
||||
|
||||
:func:`polyroots` implements the Durand-Kerner method [1], which
|
||||
uses complex arithmetic to locate all roots simultaneously.
|
||||
The Durand-Kerner method can be viewed as approximately performing
|
||||
simultaneous Newton iteration for all the roots. In particular,
|
||||
the convergence to simple roots is quadratic, just like Newton's
|
||||
method.
|
||||
|
||||
Although all roots are internally calculated using complex arithmetic,
|
||||
any root found to have an imaginary part smaller than the estimated
|
||||
numerical error is truncated to a real number. Real roots are placed
|
||||
first in the returned list, sorted by value. The remaining complex
|
||||
roots are sorted by real their parts so that conjugate roots end up
|
||||
next to each other.
|
||||
|
||||
**References**
|
||||
|
||||
1. http://en.wikipedia.org/wiki/Durand-Kerner_method
|
||||
|
||||
"""
|
||||
if len(coeffs) <= 1:
|
||||
if not coeffs or not coeffs[0]:
|
||||
raise ValueError("Input to polyroots must not be the zero polynomial")
|
||||
# Constant polynomial with no roots
|
||||
return []
|
||||
|
||||
orig = ctx.prec
|
||||
weps = +ctx.eps
|
||||
try:
|
||||
ctx.prec += 10
|
||||
tol = ctx.eps * 128
|
||||
deg = len(coeffs) - 1
|
||||
# Must be monic
|
||||
lead = ctx.convert(coeffs[0])
|
||||
if lead == 1:
|
||||
coeffs = map(ctx.convert, coeffs)
|
||||
else:
|
||||
coeffs = [c/lead for c in coeffs]
|
||||
f = lambda x: ctx.polyval(coeffs, x)
|
||||
roots = [ctx.mpc((0.4+0.9j)**n) for n in xrange(deg)]
|
||||
err = [ctx.one for n in xrange(deg)]
|
||||
# Durand-Kerner iteration until convergence
|
||||
for step in xrange(maxsteps):
|
||||
if abs(max(err)) < tol:
|
||||
break
|
||||
for i in xrange(deg):
|
||||
if not abs(err[i]) < tol:
|
||||
p = roots[i]
|
||||
x = f(p)
|
||||
for j in range(deg):
|
||||
if i != j:
|
||||
try:
|
||||
x /= (p-roots[j])
|
||||
except ZeroDivisionError:
|
||||
continue
|
||||
roots[i] = p - x
|
||||
err[i] = abs(x)
|
||||
# Remove small imaginary parts
|
||||
if cleanup:
|
||||
for i in xrange(deg):
|
||||
if abs(ctx._im(roots[i])) < weps:
|
||||
roots[i] = roots[i].real
|
||||
elif abs(ctx._re(roots[i])) < weps:
|
||||
roots[i] = roots[i].imag * 1j
|
||||
roots.sort(key=lambda x: (abs(ctx._im(x)), ctx._re(x)))
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
if error:
|
||||
err = max(err)
|
||||
err = max(err, ctx.ldexp(1, -orig+1))
|
||||
return [+r for r in roots], +err
|
||||
else:
|
||||
return [+r for r in roots]
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,8 +0,0 @@
|
|||
# The py library is part of the "py.test" testing suite (python-codespeak-lib
|
||||
# on Debian), see http://codespeak.net/py/
|
||||
|
||||
import py
|
||||
|
||||
#this makes py.test put mpath directory into the sys.path, so that we can
|
||||
#"import mpmath" from tests nicely
|
||||
rootdir = py.magic.autopath().dirpath()
|
||||
|
|
@ -1,324 +0,0 @@
|
|||
from operator import gt, lt
|
||||
|
||||
from functions.functions import SpecialFunctions
|
||||
from functions.rszeta import RSCache
|
||||
from calculus.quadrature import QuadratureMethods
|
||||
from calculus.calculus import CalculusMethods
|
||||
from calculus.optimization import OptimizationMethods
|
||||
from calculus.odes import ODEMethods
|
||||
from matrices.matrices import MatrixMethods
|
||||
from matrices.calculus import MatrixCalculusMethods
|
||||
from matrices.linalg import LinearAlgebraMethods
|
||||
from identification import IdentificationMethods
|
||||
from visualization import VisualizationMethods
|
||||
|
||||
import libmp
|
||||
|
||||
class Context(object):
|
||||
pass
|
||||
|
||||
class StandardBaseContext(Context,
|
||||
SpecialFunctions,
|
||||
RSCache,
|
||||
QuadratureMethods,
|
||||
CalculusMethods,
|
||||
MatrixMethods,
|
||||
MatrixCalculusMethods,
|
||||
LinearAlgebraMethods,
|
||||
IdentificationMethods,
|
||||
OptimizationMethods,
|
||||
ODEMethods,
|
||||
VisualizationMethods):
|
||||
|
||||
NoConvergence = libmp.NoConvergence
|
||||
ComplexResult = libmp.ComplexResult
|
||||
|
||||
def __init__(ctx):
|
||||
ctx._aliases = {}
|
||||
# Call those that need preinitialization (e.g. for wrappers)
|
||||
SpecialFunctions.__init__(ctx)
|
||||
RSCache.__init__(ctx)
|
||||
QuadratureMethods.__init__(ctx)
|
||||
CalculusMethods.__init__(ctx)
|
||||
MatrixMethods.__init__(ctx)
|
||||
|
||||
def _init_aliases(ctx):
|
||||
for alias, value in ctx._aliases.items():
|
||||
try:
|
||||
setattr(ctx, alias, getattr(ctx, value))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
_fixed_precision = False
|
||||
|
||||
# XXX
|
||||
verbose = False
|
||||
|
||||
def warn(ctx, msg):
|
||||
print "Warning:", msg
|
||||
|
||||
def bad_domain(ctx, msg):
|
||||
raise ValueError(msg)
|
||||
|
||||
def _re(ctx, x):
|
||||
if hasattr(x, "real"):
|
||||
return x.real
|
||||
return x
|
||||
|
||||
def _im(ctx, x):
|
||||
if hasattr(x, "imag"):
|
||||
return x.imag
|
||||
return ctx.zero
|
||||
|
||||
def chop(ctx, x, tol=None):
|
||||
"""
|
||||
Chops off small real or imaginary parts, or converts
|
||||
numbers close to zero to exact zeros. The input can be a
|
||||
single number or an iterable::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> chop(5+1e-10j, tol=1e-9)
|
||||
mpf('5.0')
|
||||
>>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
|
||||
[1.0, 0.0, 3.0, -4.0, 2.0]
|
||||
|
||||
The tolerance defaults to ``100*eps``.
|
||||
"""
|
||||
if tol is None:
|
||||
tol = 100*ctx.eps
|
||||
try:
|
||||
x = ctx.convert(x)
|
||||
absx = abs(x)
|
||||
if abs(x) < tol:
|
||||
return ctx.zero
|
||||
if ctx._is_complex_type(x):
|
||||
if abs(x.imag) < min(tol, absx*tol):
|
||||
return x.real
|
||||
if abs(x.real) < min(tol, absx*tol):
|
||||
return ctx.mpc(0, x.imag)
|
||||
except TypeError:
|
||||
if isinstance(x, ctx.matrix):
|
||||
return x.apply(lambda a: ctx.chop(a, tol))
|
||||
if hasattr(x, "__iter__"):
|
||||
return [ctx.chop(a, tol) for a in x]
|
||||
return x
|
||||
|
||||
def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
|
||||
r"""
|
||||
Determine whether the difference between `s` and `t` is smaller
|
||||
than a given epsilon, either relatively or absolutely.
|
||||
|
||||
Both a maximum relative difference and a maximum difference
|
||||
('epsilons') may be specified. The absolute difference is
|
||||
defined as `|s-t|` and the relative difference is defined
|
||||
as `|s-t|/\max(|s|, |t|)`.
|
||||
|
||||
If only one epsilon is given, both are set to the same value.
|
||||
If none is given, both epsilons are set to `2^{-p+m}` where
|
||||
`p` is the current working precision and `m` is a small
|
||||
integer. The default setting typically allows :func:`almosteq`
|
||||
to be used to check for mathematical equality
|
||||
in the presence of small rounding errors.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15
|
||||
>>> almosteq(3.141592653589793, 3.141592653589790)
|
||||
True
|
||||
>>> almosteq(3.141592653589793, 3.141592653589700)
|
||||
False
|
||||
>>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
|
||||
True
|
||||
>>> almosteq(1e-20, 2e-20)
|
||||
True
|
||||
>>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
|
||||
False
|
||||
|
||||
"""
|
||||
t = ctx.convert(t)
|
||||
if abs_eps is None and rel_eps is None:
|
||||
rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
|
||||
if abs_eps is None:
|
||||
abs_eps = rel_eps
|
||||
elif rel_eps is None:
|
||||
rel_eps = abs_eps
|
||||
diff = abs(s-t)
|
||||
if diff <= abs_eps:
|
||||
return True
|
||||
abss = abs(s)
|
||||
abst = abs(t)
|
||||
if abss < abst:
|
||||
err = diff/abst
|
||||
else:
|
||||
err = diff/abss
|
||||
return err <= rel_eps
|
||||
|
||||
def arange(ctx, *args):
|
||||
r"""
|
||||
This is a generalized version of Python's :func:`range` function
|
||||
that accepts fractional endpoints and step sizes and
|
||||
returns a list of ``mpf`` instances. Like :func:`range`,
|
||||
:func:`arange` can be called with 1, 2 or 3 arguments:
|
||||
|
||||
``arange(b)``
|
||||
`[0, 1, 2, \ldots, x]`
|
||||
``arange(a, b)``
|
||||
`[a, a+1, a+2, \ldots, x]`
|
||||
``arange(a, b, h)``
|
||||
`[a, a+h, a+h, \ldots, x]`
|
||||
|
||||
where `b-1 \le x < b` (in the third case, `b-h \le x < b`).
|
||||
|
||||
Like Python's :func:`range`, the endpoint is not included. To
|
||||
produce ranges where the endpoint is included, :func:`linspace`
|
||||
is more convenient.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> arange(4)
|
||||
[mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
|
||||
>>> arange(1, 2, 0.25)
|
||||
[mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
|
||||
>>> arange(1, -1, -0.75)
|
||||
[mpf('1.0'), mpf('0.25'), mpf('-0.5')]
|
||||
|
||||
"""
|
||||
if not len(args) <= 3:
|
||||
raise TypeError('arange expected at most 3 arguments, got %i'
|
||||
% len(args))
|
||||
if not len(args) >= 1:
|
||||
raise TypeError('arange expected at least 1 argument, got %i'
|
||||
% len(args))
|
||||
# set default
|
||||
a = 0
|
||||
dt = 1
|
||||
# interpret arguments
|
||||
if len(args) == 1:
|
||||
b = args[0]
|
||||
elif len(args) >= 2:
|
||||
a = args[0]
|
||||
b = args[1]
|
||||
if len(args) == 3:
|
||||
dt = args[2]
|
||||
a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
|
||||
assert a + dt != a, 'dt is too small and would cause an infinite loop'
|
||||
# adapt code for sign of dt
|
||||
if a > b:
|
||||
if dt > 0:
|
||||
return []
|
||||
op = gt
|
||||
else:
|
||||
if dt < 0:
|
||||
return []
|
||||
op = lt
|
||||
# create list
|
||||
result = []
|
||||
i = 0
|
||||
t = a
|
||||
while 1:
|
||||
t = a + dt*i
|
||||
i += 1
|
||||
if op(t, b):
|
||||
result.append(t)
|
||||
else:
|
||||
break
|
||||
return result
|
||||
|
||||
def linspace(ctx, *args, **kwargs):
|
||||
"""
|
||||
``linspace(a, b, n)`` returns a list of `n` evenly spaced
|
||||
samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
|
||||
is also valid.
|
||||
|
||||
This function is often more convenient than :func:`arange`
|
||||
for partitioning an interval into subintervals, since
|
||||
the endpoint is included::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> linspace(1, 4, 4)
|
||||
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
||||
>>> linspace(mpi(1,4), 4)
|
||||
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
||||
|
||||
You may also provide the keyword argument ``endpoint=False``::
|
||||
|
||||
>>> linspace(1, 4, 4, endpoint=False)
|
||||
[mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]
|
||||
|
||||
"""
|
||||
if len(args) == 3:
|
||||
a = ctx.mpf(args[0])
|
||||
b = ctx.mpf(args[1])
|
||||
n = int(args[2])
|
||||
elif len(args) == 2:
|
||||
assert hasattr(args[0], '_mpi_')
|
||||
a = args[0].a
|
||||
b = args[0].b
|
||||
n = int(args[1])
|
||||
else:
|
||||
raise TypeError('linspace expected 2 or 3 arguments, got %i' \
|
||||
% len(args))
|
||||
if n < 1:
|
||||
raise ValueError('n must be greater than 0')
|
||||
if not 'endpoint' in kwargs or kwargs['endpoint']:
|
||||
if n == 1:
|
||||
return [ctx.mpf(a)]
|
||||
step = (b - a) / ctx.mpf(n - 1)
|
||||
y = [i*step + a for i in xrange(n)]
|
||||
y[-1] = b
|
||||
else:
|
||||
step = (b - a) / ctx.mpf(n)
|
||||
y = [i*step + a for i in xrange(n)]
|
||||
return y
|
||||
|
||||
def cos_sin(ctx, z, **kwargs):
|
||||
return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)
|
||||
|
||||
def _default_hyper_maxprec(ctx, p):
|
||||
return int(1000 * p**0.25 + 4*p)
|
||||
|
||||
_gcd = staticmethod(libmp.gcd)
|
||||
list_primes = staticmethod(libmp.list_primes)
|
||||
bernfrac = staticmethod(libmp.bernfrac)
|
||||
moebius = staticmethod(libmp.moebius)
|
||||
_ifac = staticmethod(libmp.ifac)
|
||||
_eulernum = staticmethod(libmp.eulernum)
|
||||
|
||||
def sum_accurately(ctx, terms, check_step=1):
|
||||
prec = ctx.prec
|
||||
try:
|
||||
extraprec = 10
|
||||
while 1:
|
||||
ctx.prec = prec + extraprec + 5
|
||||
max_mag = ctx.ninf
|
||||
s = ctx.zero
|
||||
k = 0
|
||||
for term in terms():
|
||||
s += term
|
||||
if (not k % check_step) and term:
|
||||
term_mag = ctx.mag(term)
|
||||
max_mag = max(max_mag, term_mag)
|
||||
sum_mag = ctx.mag(s)
|
||||
if sum_mag - term_mag > ctx.prec:
|
||||
break
|
||||
k += 1
|
||||
cancellation = max_mag - sum_mag
|
||||
if cancellation != cancellation:
|
||||
break
|
||||
if cancellation < extraprec or ctx._fixed_precision:
|
||||
break
|
||||
extraprec += min(ctx.prec, cancellation)
|
||||
return s
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
|
||||
def power(ctx, x, y):
|
||||
return ctx.convert(x) ** ctx.convert(y)
|
||||
|
||||
def _zeta_int(ctx, n):
|
||||
return ctx.zeta(n)
|
||||
|
|
@ -1,278 +0,0 @@
|
|||
from ctx_base import StandardBaseContext
|
||||
|
||||
import math
|
||||
import cmath
|
||||
import math2
|
||||
|
||||
import function_docs
|
||||
|
||||
from libmp import mpf_bernoulli, to_float, int_types
|
||||
import libmp
|
||||
|
||||
class FPContext(StandardBaseContext):
|
||||
"""
|
||||
Context for fast low-precision arithmetic (53-bit precision, giving at most
|
||||
about 15-digit accuracy), using Python's builtin float and complex.
|
||||
"""
|
||||
|
||||
def __init__(ctx):
|
||||
StandardBaseContext.__init__(ctx)
|
||||
|
||||
# Override SpecialFunctions implementation
|
||||
ctx.loggamma = math2.loggamma
|
||||
ctx._bernoulli_cache = {}
|
||||
ctx.pretty = False
|
||||
|
||||
ctx._init_aliases()
|
||||
|
||||
_mpq = lambda cls, x: float(x[0])/x[1]
|
||||
|
||||
NoConvergence = libmp.NoConvergence
|
||||
|
||||
def _get_prec(ctx): return 53
|
||||
def _set_prec(ctx, p): return
|
||||
def _get_dps(ctx): return 15
|
||||
def _set_dps(ctx, p): return
|
||||
|
||||
_fixed_precision = True
|
||||
|
||||
prec = property(_get_prec, _set_prec)
|
||||
dps = property(_get_dps, _set_dps)
|
||||
|
||||
zero = 0.0
|
||||
one = 1.0
|
||||
eps = math2.EPS
|
||||
inf = math2.INF
|
||||
ninf = math2.NINF
|
||||
nan = math2.NAN
|
||||
j = 1j
|
||||
|
||||
# Called by SpecialFunctions.__init__()
|
||||
@classmethod
|
||||
def _wrap_specfun(cls, name, f, wrap):
|
||||
if wrap:
|
||||
def f_wrapped(ctx, *args, **kwargs):
|
||||
convert = ctx.convert
|
||||
args = [convert(a) for a in args]
|
||||
return f(ctx, *args, **kwargs)
|
||||
else:
|
||||
f_wrapped = f
|
||||
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
|
||||
setattr(cls, name, f_wrapped)
|
||||
|
||||
def bernoulli(ctx, n):
|
||||
cache = ctx._bernoulli_cache
|
||||
if n in cache:
|
||||
return cache[n]
|
||||
cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
|
||||
return cache[n]
|
||||
|
||||
pi = math2.pi
|
||||
e = math2.e
|
||||
euler = math2.euler
|
||||
sqrt2 = 1.4142135623730950488
|
||||
sqrt5 = 2.2360679774997896964
|
||||
phi = 1.6180339887498948482
|
||||
ln2 = 0.69314718055994530942
|
||||
ln10 = 2.302585092994045684
|
||||
euler = 0.57721566490153286061
|
||||
catalan = 0.91596559417721901505
|
||||
khinchin = 2.6854520010653064453
|
||||
apery = 1.2020569031595942854
|
||||
|
||||
absmin = absmax = abs
|
||||
|
||||
def _as_points(ctx, x):
|
||||
return x
|
||||
|
||||
def fneg(ctx, x, **kwargs):
|
||||
return -ctx.convert(x)
|
||||
|
||||
def fadd(ctx, x, y, **kwargs):
|
||||
return ctx.convert(x)+ctx.convert(y)
|
||||
|
||||
def fsub(ctx, x, y, **kwargs):
|
||||
return ctx.convert(x)-ctx.convert(y)
|
||||
|
||||
def fmul(ctx, x, y, **kwargs):
|
||||
return ctx.convert(x)*ctx.convert(y)
|
||||
|
||||
def fdiv(ctx, x, y, **kwargs):
|
||||
return ctx.convert(x)/ctx.convert(y)
|
||||
|
||||
def fsum(ctx, args, absolute=False, squared=False):
|
||||
if absolute:
|
||||
if squared:
|
||||
return sum((abs(x)**2 for x in args), ctx.zero)
|
||||
return sum((abs(x) for x in args), ctx.zero)
|
||||
if squared:
|
||||
return sum((x**2 for x in args), ctx.zero)
|
||||
return sum(args, ctx.zero)
|
||||
|
||||
def fdot(ctx, xs, ys=None):
|
||||
if ys is not None:
|
||||
xs = zip(xs, ys)
|
||||
return sum((x*y for (x,y) in xs), ctx.zero)
|
||||
|
||||
def is_special(ctx, x):
|
||||
return x - x != 0.0
|
||||
|
||||
def isnan(ctx, x):
|
||||
return x != x
|
||||
|
||||
def isinf(ctx, x):
|
||||
return abs(x) == math2.INF
|
||||
|
||||
def isnpint(ctx, x):
|
||||
if type(x) is complex:
|
||||
if x.imag:
|
||||
return False
|
||||
x = x.real
|
||||
return x <= 0.0 and round(x) == x
|
||||
|
||||
mpf = float
|
||||
mpc = complex
|
||||
|
||||
def convert(ctx, x):
|
||||
try:
|
||||
return float(x)
|
||||
except:
|
||||
return complex(x)
|
||||
|
||||
power = staticmethod(math2.pow)
|
||||
sqrt = staticmethod(math2.sqrt)
|
||||
exp = staticmethod(math2.exp)
|
||||
ln = log = staticmethod(math2.log)
|
||||
cos = staticmethod(math2.cos)
|
||||
sin = staticmethod(math2.sin)
|
||||
tan = staticmethod(math2.tan)
|
||||
cos_sin = staticmethod(math2.cos_sin)
|
||||
acos = staticmethod(math2.acos)
|
||||
asin = staticmethod(math2.asin)
|
||||
atan = staticmethod(math2.atan)
|
||||
cosh = staticmethod(math2.cosh)
|
||||
sinh = staticmethod(math2.sinh)
|
||||
tanh = staticmethod(math2.tanh)
|
||||
gamma = staticmethod(math2.gamma)
|
||||
fac = factorial = staticmethod(math2.factorial)
|
||||
floor = staticmethod(math2.floor)
|
||||
ceil = staticmethod(math2.ceil)
|
||||
cospi = staticmethod(math2.cospi)
|
||||
sinpi = staticmethod(math2.sinpi)
|
||||
cbrt = staticmethod(math2.cbrt)
|
||||
_nthroot = staticmethod(math2.nthroot)
|
||||
_ei = staticmethod(math2.ei)
|
||||
_e1 = staticmethod(math2.e1)
|
||||
_zeta = _zeta_int = staticmethod(math2.zeta)
|
||||
|
||||
# XXX: math2
|
||||
def arg(ctx, z):
|
||||
z = complex(z)
|
||||
return math.atan2(z.imag, z.real)
|
||||
|
||||
def expj(ctx, x):
|
||||
return ctx.exp(ctx.j*x)
|
||||
|
||||
def expjpi(ctx, x):
|
||||
return ctx.exp(ctx.j*ctx.pi*x)
|
||||
|
||||
ldexp = math.ldexp
|
||||
frexp = math.frexp
|
||||
|
||||
def mag(ctx, z):
|
||||
if z:
|
||||
return ctx.frexp(abs(z))[1]
|
||||
return ctx.ninf
|
||||
|
||||
def isint(ctx, z):
|
||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
||||
if z.imag:
|
||||
return False
|
||||
z = z.real
|
||||
try:
|
||||
return z == int(z)
|
||||
except:
|
||||
return False
|
||||
|
||||
def nint_distance(ctx, z):
|
||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
||||
n = round(z.real)
|
||||
else:
|
||||
n = round(z)
|
||||
if n == z:
|
||||
return n, ctx.ninf
|
||||
return n, ctx.mag(abs(z-n))
|
||||
|
||||
def _convert_param(ctx, z):
|
||||
if type(z) is tuple:
|
||||
p, q = z
|
||||
return ctx.mpf(p) / q, 'R'
|
||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
||||
intz = int(z.real)
|
||||
else:
|
||||
intz = int(z)
|
||||
if z == intz:
|
||||
return intz, 'Z'
|
||||
return z, 'R'
|
||||
|
||||
def _is_real_type(ctx, z):
|
||||
return isinstance(z, float) or isinstance(z, int_types)
|
||||
|
||||
def _is_complex_type(ctx, z):
|
||||
return isinstance(z, complex)
|
||||
|
||||
def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
|
||||
coeffs = list(coeffs)
|
||||
num = range(p)
|
||||
den = range(p,p+q)
|
||||
tol = ctx.eps
|
||||
s = t = 1.0
|
||||
k = 0
|
||||
while 1:
|
||||
for i in num: t *= (coeffs[i]+k)
|
||||
for i in den: t /= (coeffs[i]+k)
|
||||
k += 1; t /= k; t *= z; s += t
|
||||
if abs(t) < tol:
|
||||
return s
|
||||
if k > maxterms:
|
||||
raise ctx.NoConvergence
|
||||
|
||||
def atan2(ctx, x, y):
|
||||
return math.atan2(x, y)
|
||||
|
||||
def psi(ctx, m, z):
|
||||
m = int(m)
|
||||
if m == 0:
|
||||
return ctx.digamma(z)
|
||||
return (-1)**(m+1) * ctx.fac(m) * ctx.zeta(m+1, z)
|
||||
|
||||
digamma = staticmethod(math2.digamma)
|
||||
|
||||
def harmonic(ctx, x):
|
||||
x = ctx.convert(x)
|
||||
if x == 0 or x == 1:
|
||||
return x
|
||||
return ctx.digamma(x+1) + ctx.euler
|
||||
|
||||
nstr = str
|
||||
|
||||
def to_fixed(ctx, x, prec):
|
||||
return int(math.ldexp(x, prec))
|
||||
|
||||
def rand(ctx):
|
||||
import random
|
||||
return random.random()
|
||||
|
||||
_erf = staticmethod(math2.erf)
|
||||
_erfc = staticmethod(math2.erfc)
|
||||
|
||||
def sum_accurately(ctx, terms, check_step=1):
|
||||
s = ctx.zero
|
||||
k = 0
|
||||
for term in terms():
|
||||
s += term
|
||||
if (not k % check_step) and term:
|
||||
if abs(term) <= 1e-18*abs(s):
|
||||
break
|
||||
k += 1
|
||||
return s
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,986 +0,0 @@
|
|||
#from ctx_base import StandardBaseContext
|
||||
|
||||
from libmp import (MPZ, MPZ_ZERO, MPZ_ONE, int_types, repr_dps,
|
||||
round_floor, round_ceiling, dps_to_prec, round_nearest, prec_to_dps,
|
||||
ComplexResult, to_pickable, from_pickable, normalize,
|
||||
from_int, from_float, from_str, to_int, to_float, to_str,
|
||||
from_rational, from_man_exp,
|
||||
fone, fzero, finf, fninf, fnan,
|
||||
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul, mpf_mul_int,
|
||||
mpf_div, mpf_rdiv_int, mpf_pow_int, mpf_mod,
|
||||
mpf_eq, mpf_cmp, mpf_lt, mpf_gt, mpf_le, mpf_ge,
|
||||
mpf_hash, mpf_rand,
|
||||
mpf_sum,
|
||||
bitcount, to_fixed,
|
||||
mpc_to_str,
|
||||
mpc_to_complex, mpc_hash, mpc_pos, mpc_is_nonzero, mpc_neg, mpc_conjugate,
|
||||
mpc_abs, mpc_add, mpc_add_mpf, mpc_sub, mpc_sub_mpf, mpc_mul, mpc_mul_mpf,
|
||||
mpc_mul_int, mpc_div, mpc_div_mpf, mpc_pow, mpc_pow_mpf, mpc_pow_int,
|
||||
mpc_mpf_div,
|
||||
mpf_pow,
|
||||
mpi_mid, mpi_delta, mpi_str,
|
||||
mpi_abs, mpi_pos, mpi_neg, mpi_add, mpi_sub,
|
||||
mpi_mul, mpi_div, mpi_pow_int, mpi_pow,
|
||||
mpf_pi, mpf_degree, mpf_e, mpf_phi, mpf_ln2, mpf_ln10,
|
||||
mpf_euler, mpf_catalan, mpf_apery, mpf_khinchin,
|
||||
mpf_glaisher, mpf_twinprime, mpf_mertens,
|
||||
int_types)
|
||||
|
||||
import rational
|
||||
import function_docs
|
||||
|
||||
new = object.__new__
|
||||
|
||||
class mpnumeric(object):
|
||||
"""Base class for mpf and mpc."""
|
||||
__slots__ = []
|
||||
def __new__(cls, val):
|
||||
raise NotImplementedError
|
||||
|
||||
class _mpf(mpnumeric):
|
||||
"""
|
||||
An mpf instance holds a real-valued floating-point number. mpf:s
|
||||
work analogously to Python floats, but support arbitrary-precision
|
||||
arithmetic.
|
||||
"""
|
||||
__slots__ = ['_mpf_']
|
||||
|
||||
def __new__(cls, val=fzero, **kwargs):
|
||||
"""A new mpf can be created from a Python float, an int, a
|
||||
or a decimal string representing a number in floating-point
|
||||
format."""
|
||||
prec, rounding = cls.context._prec_rounding
|
||||
if kwargs:
|
||||
prec = kwargs.get('prec', prec)
|
||||
if 'dps' in kwargs:
|
||||
prec = dps_to_prec(kwargs['dps'])
|
||||
rounding = kwargs.get('rounding', rounding)
|
||||
if type(val) is cls:
|
||||
sign, man, exp, bc = val._mpf_
|
||||
if (not man) and exp:
|
||||
return val
|
||||
v = new(cls)
|
||||
v._mpf_ = normalize(sign, man, exp, bc, prec, rounding)
|
||||
return v
|
||||
elif type(val) is tuple:
|
||||
if len(val) == 2:
|
||||
v = new(cls)
|
||||
v._mpf_ = from_man_exp(val[0], val[1], prec, rounding)
|
||||
return v
|
||||
if len(val) == 4:
|
||||
sign, man, exp, bc = val
|
||||
v = new(cls)
|
||||
v._mpf_ = normalize(sign, MPZ(man), exp, bc, prec, rounding)
|
||||
return v
|
||||
raise ValueError
|
||||
else:
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_pos(cls.mpf_convert_arg(val, prec, rounding), prec, rounding)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def mpf_convert_arg(cls, x, prec, rounding):
|
||||
if isinstance(x, int_types): return from_int(x)
|
||||
if isinstance(x, float): return from_float(x)
|
||||
if isinstance(x, basestring): return from_str(x, prec, rounding)
|
||||
if isinstance(x, cls.context.constant): return x.func(prec, rounding)
|
||||
if hasattr(x, '_mpf_'): return x._mpf_
|
||||
if hasattr(x, '_mpmath_'):
|
||||
t = cls.context.convert(x._mpmath_(prec, rounding))
|
||||
if hasattr(t, '_mpf_'):
|
||||
return t._mpf_
|
||||
raise TypeError("cannot create mpf from " + repr(x))
|
||||
|
||||
@classmethod
|
||||
def mpf_convert_rhs(cls, x):
|
||||
if isinstance(x, int_types): return from_int(x)
|
||||
if isinstance(x, float): return from_float(x)
|
||||
if isinstance(x, complex_types): return cls.context.mpc(x)
|
||||
if isinstance(x, rational.mpq):
|
||||
p, q = x
|
||||
return from_rational(p, q, cls.context.prec)
|
||||
if hasattr(x, '_mpf_'): return x._mpf_
|
||||
if hasattr(x, '_mpmath_'):
|
||||
t = cls.context.convert(x._mpmath_(*cls.context._prec_rounding))
|
||||
if hasattr(t, '_mpf_'):
|
||||
return t._mpf_
|
||||
return t
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def mpf_convert_lhs(cls, x):
|
||||
x = cls.mpf_convert_rhs(x)
|
||||
if type(x) is tuple:
|
||||
return cls.context.make_mpf(x)
|
||||
return x
|
||||
|
||||
man_exp = property(lambda self: self._mpf_[1:3])
|
||||
man = property(lambda self: self._mpf_[1])
|
||||
exp = property(lambda self: self._mpf_[2])
|
||||
bc = property(lambda self: self._mpf_[3])
|
||||
|
||||
real = property(lambda self: self)
|
||||
imag = property(lambda self: self.context.zero)
|
||||
|
||||
conjugate = lambda self: self
|
||||
|
||||
def __getstate__(self): return to_pickable(self._mpf_)
|
||||
def __setstate__(self, val): self._mpf_ = from_pickable(val)
|
||||
|
||||
def __repr__(s):
|
||||
if s.context.pretty:
|
||||
return str(s)
|
||||
return "mpf('%s')" % to_str(s._mpf_, s.context._repr_digits)
|
||||
|
||||
def __str__(s): return to_str(s._mpf_, s.context._str_digits)
|
||||
def __hash__(s): return mpf_hash(s._mpf_)
|
||||
def __int__(s): return int(to_int(s._mpf_))
|
||||
def __long__(s): return long(to_int(s._mpf_))
|
||||
def __float__(s): return to_float(s._mpf_)
|
||||
def __complex__(s): return complex(float(s))
|
||||
def __nonzero__(s): return s._mpf_ != fzero
|
||||
|
||||
def __abs__(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_abs(s._mpf_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __pos__(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_pos(s._mpf_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __neg__(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_neg(s._mpf_, prec, rounding)
|
||||
return v
|
||||
|
||||
def _cmp(s, t, func):
|
||||
if hasattr(t, '_mpf_'):
|
||||
t = t._mpf_
|
||||
else:
|
||||
t = s.mpf_convert_rhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return func(s._mpf_, t)
|
||||
|
||||
def __cmp__(s, t): return s._cmp(t, mpf_cmp)
|
||||
def __lt__(s, t): return s._cmp(t, mpf_lt)
|
||||
def __gt__(s, t): return s._cmp(t, mpf_gt)
|
||||
def __le__(s, t): return s._cmp(t, mpf_le)
|
||||
def __ge__(s, t): return s._cmp(t, mpf_ge)
|
||||
|
||||
def __ne__(s, t):
|
||||
v = s.__eq__(t)
|
||||
if v is NotImplemented:
|
||||
return v
|
||||
return not v
|
||||
|
||||
def __rsub__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if type(t) in int_types:
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_sub(from_int(t), s._mpf_, prec, rounding)
|
||||
return v
|
||||
t = s.mpf_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t - s
|
||||
|
||||
def __rdiv__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if isinstance(t, int_types):
|
||||
v = new(cls)
|
||||
v._mpf_ = mpf_rdiv_int(t, s._mpf_, prec, rounding)
|
||||
return v
|
||||
t = s.mpf_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t / s
|
||||
|
||||
def __rpow__(s, t):
|
||||
t = s.mpf_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t ** s
|
||||
|
||||
def __rmod__(s, t):
|
||||
t = s.mpf_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t % s
|
||||
|
||||
def sqrt(s):
|
||||
return s.context.sqrt(s)
|
||||
|
||||
def ae(s, t, rel_eps=None, abs_eps=None):
|
||||
return s.context.almosteq(s, t, rel_eps, abs_eps)
|
||||
|
||||
def to_fixed(self, prec):
|
||||
return to_fixed(self._mpf_, prec)
|
||||
|
||||
|
||||
mpf_binary_op = """
|
||||
def %NAME%(self, other):
|
||||
mpf, new, (prec, rounding) = self._ctxdata
|
||||
sval = self._mpf_
|
||||
if hasattr(other, '_mpf_'):
|
||||
tval = other._mpf_
|
||||
%WITH_MPF%
|
||||
ttype = type(other)
|
||||
if ttype in int_types:
|
||||
%WITH_INT%
|
||||
elif ttype is float:
|
||||
tval = from_float(other)
|
||||
%WITH_MPF%
|
||||
elif hasattr(other, '_mpc_'):
|
||||
tval = other._mpc_
|
||||
mpc = type(other)
|
||||
%WITH_MPC%
|
||||
elif ttype is complex:
|
||||
tval = from_float(other.real), from_float(other.imag)
|
||||
mpc = self.context.mpc
|
||||
%WITH_MPC%
|
||||
if isinstance(other, mpnumeric):
|
||||
return NotImplemented
|
||||
try:
|
||||
other = mpf.context.convert(other, strings=False)
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
return self.%NAME%(other)
|
||||
"""
|
||||
|
||||
return_mpf = "; obj = new(mpf); obj._mpf_ = val; return obj"
|
||||
return_mpc = "; obj = new(mpc); obj._mpc_ = val; return obj"
|
||||
|
||||
mpf_pow_same = """
|
||||
try:
|
||||
val = mpf_pow(sval, tval, prec, rounding) %s
|
||||
except ComplexResult:
|
||||
if mpf.context.trap_complex:
|
||||
raise
|
||||
mpc = mpf.context.mpc
|
||||
val = mpc_pow((sval, fzero), (tval, fzero), prec, rounding) %s
|
||||
""" % (return_mpf, return_mpc)
|
||||
|
||||
def binary_op(name, with_mpf='', with_int='', with_mpc=''):
|
||||
code = mpf_binary_op
|
||||
code = code.replace("%WITH_INT%", with_int)
|
||||
code = code.replace("%WITH_MPC%", with_mpc)
|
||||
code = code.replace("%WITH_MPF%", with_mpf)
|
||||
code = code.replace("%NAME%", name)
|
||||
np = {}
|
||||
exec code in globals(), np
|
||||
return np[name]
|
||||
|
||||
_mpf.__eq__ = binary_op('__eq__',
|
||||
'return mpf_eq(sval, tval)',
|
||||
'return mpf_eq(sval, from_int(other))',
|
||||
'return (tval[1] == fzero) and mpf_eq(tval[0], sval)')
|
||||
|
||||
_mpf.__add__ = binary_op('__add__',
|
||||
'val = mpf_add(sval, tval, prec, rounding)' + return_mpf,
|
||||
'val = mpf_add(sval, from_int(other), prec, rounding)' + return_mpf,
|
||||
'val = mpc_add_mpf(tval, sval, prec, rounding)' + return_mpc)
|
||||
|
||||
_mpf.__sub__ = binary_op('__sub__',
|
||||
'val = mpf_sub(sval, tval, prec, rounding)' + return_mpf,
|
||||
'val = mpf_sub(sval, from_int(other), prec, rounding)' + return_mpf,
|
||||
'val = mpc_sub((sval, fzero), tval, prec, rounding)' + return_mpc)
|
||||
|
||||
_mpf.__mul__ = binary_op('__mul__',
|
||||
'val = mpf_mul(sval, tval, prec, rounding)' + return_mpf,
|
||||
'val = mpf_mul_int(sval, other, prec, rounding)' + return_mpf,
|
||||
'val = mpc_mul_mpf(tval, sval, prec, rounding)' + return_mpc)
|
||||
|
||||
_mpf.__div__ = binary_op('__div__',
|
||||
'val = mpf_div(sval, tval, prec, rounding)' + return_mpf,
|
||||
'val = mpf_div(sval, from_int(other), prec, rounding)' + return_mpf,
|
||||
'val = mpc_mpf_div(sval, tval, prec, rounding)' + return_mpc)
|
||||
|
||||
_mpf.__mod__ = binary_op('__mod__',
|
||||
'val = mpf_mod(sval, tval, prec, rounding)' + return_mpf,
|
||||
'val = mpf_mod(sval, from_int(other), prec, rounding)' + return_mpf,
|
||||
'raise NotImplementedError("complex modulo")')
|
||||
|
||||
_mpf.__pow__ = binary_op('__pow__',
|
||||
mpf_pow_same,
|
||||
'val = mpf_pow_int(sval, other, prec, rounding)' + return_mpf,
|
||||
'val = mpc_pow((sval, fzero), tval, prec, rounding)' + return_mpc)
|
||||
|
||||
_mpf.__radd__ = _mpf.__add__
|
||||
_mpf.__rmul__ = _mpf.__mul__
|
||||
_mpf.__truediv__ = _mpf.__div__
|
||||
_mpf.__rtruediv__ = _mpf.__rdiv__
|
||||
|
||||
|
||||
class _constant(_mpf):
|
||||
"""Represents a mathematical constant with dynamic precision.
|
||||
When printed or used in an arithmetic operation, a constant
|
||||
is converted to a regular mpf at the working precision. A
|
||||
regular mpf can also be obtained using the operation +x."""
|
||||
|
||||
def __new__(cls, func, name, docname=''):
|
||||
a = object.__new__(cls)
|
||||
a.name = name
|
||||
a.func = func
|
||||
a.__doc__ = getattr(function_docs, docname, '')
|
||||
return a
|
||||
|
||||
def __call__(self, prec=None, dps=None, rounding=None):
|
||||
prec2, rounding2 = self.context._prec_rounding
|
||||
if not prec: prec = prec2
|
||||
if not rounding: rounding = rounding2
|
||||
if dps: prec = dps_to_prec(dps)
|
||||
return self.context.make_mpf(self.func(prec, rounding))
|
||||
|
||||
@property
|
||||
def _mpf_(self):
|
||||
prec, rounding = self.context._prec_rounding
|
||||
return self.func(prec, rounding)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s~>" % (self.name, self.context.nstr(self))
|
||||
|
||||
|
||||
class _mpc(mpnumeric):
|
||||
"""
|
||||
An mpc represents a complex number using a pair of mpf:s (one
|
||||
for the real part and another for the imaginary part.) The mpc
|
||||
class behaves fairly similarly to Python's complex type.
|
||||
"""
|
||||
|
||||
__slots__ = ['_mpc_']
|
||||
|
||||
def __new__(cls, real=0, imag=0):
|
||||
s = object.__new__(cls)
|
||||
if isinstance(real, complex_types):
|
||||
real, imag = real.real, real.imag
|
||||
elif hasattr(real, '_mpc_'):
|
||||
s._mpc_ = real._mpc_
|
||||
return s
|
||||
real = cls.context.mpf(real)
|
||||
imag = cls.context.mpf(imag)
|
||||
s._mpc_ = (real._mpf_, imag._mpf_)
|
||||
return s
|
||||
|
||||
real = property(lambda self: self.context.make_mpf(self._mpc_[0]))
|
||||
imag = property(lambda self: self.context.make_mpf(self._mpc_[1]))
|
||||
|
||||
def __getstate__(self):
|
||||
return to_pickable(self._mpc_[0]), to_pickable(self._mpc_[1])
|
||||
|
||||
def __setstate__(self, val):
|
||||
self._mpc_ = from_pickable(val[0]), from_pickable(val[1])
|
||||
|
||||
def __repr__(s):
|
||||
if s.context.pretty:
|
||||
return str(s)
|
||||
r = repr(s.real)[4:-1]
|
||||
i = repr(s.imag)[4:-1]
|
||||
return "%s(real=%s, imag=%s)" % (type(s).__name__, r, i)
|
||||
|
||||
def __str__(s):
|
||||
return "(%s)" % mpc_to_str(s._mpc_, s.context._str_digits)
|
||||
|
||||
def __complex__(s):
|
||||
return mpc_to_complex(s._mpc_)
|
||||
|
||||
def __pos__(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_pos(s._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __abs__(s):
|
||||
prec, rounding = s.context._prec_rounding
|
||||
v = new(s.context.mpf)
|
||||
v._mpf_ = mpc_abs(s._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __neg__(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_neg(s._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def conjugate(s):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_conjugate(s._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __nonzero__(s):
|
||||
return mpc_is_nonzero(s._mpc_)
|
||||
|
||||
def __hash__(s):
|
||||
return mpc_hash(s._mpc_)
|
||||
|
||||
@classmethod
|
||||
def mpc_convert_lhs(cls, x):
|
||||
try:
|
||||
y = cls.context.convert(x)
|
||||
return y
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(s, t):
|
||||
if not hasattr(t, '_mpc_'):
|
||||
if isinstance(t, str):
|
||||
return False
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return s.real == t.real and s.imag == t.imag
|
||||
|
||||
def __ne__(s, t):
|
||||
b = s.__eq__(t)
|
||||
if b is NotImplemented:
|
||||
return b
|
||||
return not b
|
||||
|
||||
def _compare(*args):
|
||||
raise TypeError("no ordering relation is defined for complex numbers")
|
||||
|
||||
__gt__ = _compare
|
||||
__le__ = _compare
|
||||
__gt__ = _compare
|
||||
__ge__ = _compare
|
||||
|
||||
def __add__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if not hasattr(t, '_mpc_'):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
if hasattr(t, '_mpf_'):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_add_mpf(s._mpc_, t._mpf_, prec, rounding)
|
||||
return v
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_add(s._mpc_, t._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __sub__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if not hasattr(t, '_mpc_'):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
if hasattr(t, '_mpf_'):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_sub_mpf(s._mpc_, t._mpf_, prec, rounding)
|
||||
return v
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_sub(s._mpc_, t._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __mul__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if not hasattr(t, '_mpc_'):
|
||||
if isinstance(t, int_types):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
|
||||
return v
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
if hasattr(t, '_mpf_'):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_mul_mpf(s._mpc_, t._mpf_, prec, rounding)
|
||||
return v
|
||||
t = s.mpc_convert_lhs(t)
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_mul(s._mpc_, t._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __div__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if not hasattr(t, '_mpc_'):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
if hasattr(t, '_mpf_'):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_div_mpf(s._mpc_, t._mpf_, prec, rounding)
|
||||
return v
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_div(s._mpc_, t._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
def __pow__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if isinstance(t, int_types):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_pow_int(s._mpc_, t, prec, rounding)
|
||||
return v
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
v = new(cls)
|
||||
if hasattr(t, '_mpf_'):
|
||||
v._mpc_ = mpc_pow_mpf(s._mpc_, t._mpf_, prec, rounding)
|
||||
else:
|
||||
v._mpc_ = mpc_pow(s._mpc_, t._mpc_, prec, rounding)
|
||||
return v
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __rsub__(s, t):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t - s
|
||||
|
||||
def __rmul__(s, t):
|
||||
cls, new, (prec, rounding) = s._ctxdata
|
||||
if isinstance(t, int_types):
|
||||
v = new(cls)
|
||||
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
|
||||
return v
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t * s
|
||||
|
||||
def __rdiv__(s, t):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t / s
|
||||
|
||||
def __rpow__(s, t):
|
||||
t = s.mpc_convert_lhs(t)
|
||||
if t is NotImplemented:
|
||||
return t
|
||||
return t ** s
|
||||
|
||||
__truediv__ = __div__
|
||||
__rtruediv__ = __rdiv__
|
||||
|
||||
def ae(s, t, rel_eps=None, abs_eps=None):
|
||||
return s.context.almosteq(s, t, rel_eps, abs_eps)
|
||||
|
||||
|
||||
complex_types = (complex, _mpc)
|
||||
|
||||
|
||||
class PythonMPContext:
|
||||
|
||||
def __init__(ctx):
|
||||
ctx._prec_rounding = [53, round_nearest]
|
||||
ctx.mpf = type('mpf', (_mpf,), {})
|
||||
ctx.mpc = type('mpc', (_mpc,), {})
|
||||
ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
|
||||
ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding]
|
||||
ctx.mpf.context = ctx
|
||||
ctx.mpc.context = ctx
|
||||
ctx.constant = type('constant', (_constant,), {})
|
||||
ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
|
||||
ctx.constant.context = ctx
|
||||
|
||||
def make_mpf(ctx, v):
|
||||
a = new(ctx.mpf)
|
||||
a._mpf_ = v
|
||||
return a
|
||||
|
||||
def make_mpc(ctx, v):
|
||||
a = new(ctx.mpc)
|
||||
a._mpc_ = v
|
||||
return a
|
||||
|
||||
def default(ctx):
|
||||
ctx._prec = ctx._prec_rounding[0] = 53
|
||||
ctx._dps = 15
|
||||
ctx.trap_complex = False
|
||||
|
||||
def _set_prec(ctx, n):
|
||||
ctx._prec = ctx._prec_rounding[0] = max(1, int(n))
|
||||
ctx._dps = prec_to_dps(n)
|
||||
|
||||
def _set_dps(ctx, n):
|
||||
ctx._prec = ctx._prec_rounding[0] = dps_to_prec(n)
|
||||
ctx._dps = max(1, int(n))
|
||||
|
||||
prec = property(lambda ctx: ctx._prec, _set_prec)
|
||||
dps = property(lambda ctx: ctx._dps, _set_dps)
|
||||
|
||||
def convert(ctx, x, strings=True):
|
||||
"""
|
||||
Converts *x* to an ``mpf``, ``mpc`` or ``mpi``. If *x* is of type ``mpf``,
|
||||
``mpc``, ``int``, ``float``, ``complex``, the conversion
|
||||
will be performed losslessly.
|
||||
|
||||
If *x* is a string, the result will be rounded to the present
|
||||
working precision. Strings representing fractions or complex
|
||||
numbers are permitted.
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> mpmathify(3.5)
|
||||
mpf('3.5')
|
||||
>>> mpmathify('2.1')
|
||||
mpf('2.1000000000000001')
|
||||
>>> mpmathify('3/4')
|
||||
mpf('0.75')
|
||||
>>> mpmathify('2+3j')
|
||||
mpc(real='2.0', imag='3.0')
|
||||
|
||||
"""
|
||||
if type(x) in ctx.types: return x
|
||||
if isinstance(x, int_types): return ctx.make_mpf(from_int(x))
|
||||
if isinstance(x, float): return ctx.make_mpf(from_float(x))
|
||||
if isinstance(x, complex):
|
||||
return ctx.make_mpc((from_float(x.real), from_float(x.imag)))
|
||||
prec, rounding = ctx._prec_rounding
|
||||
if isinstance(x, rational.mpq):
|
||||
p, q = x
|
||||
return ctx.make_mpf(from_rational(p, q, prec))
|
||||
if strings and isinstance(x, basestring):
|
||||
try:
|
||||
_mpf_ = from_str(x, prec, rounding)
|
||||
return ctx.make_mpf(_mpf_)
|
||||
except ValueError:
|
||||
pass
|
||||
if hasattr(x, '_mpf_'): return ctx.make_mpf(x._mpf_)
|
||||
if hasattr(x, '_mpc_'): return ctx.make_mpc(x._mpc_)
|
||||
if hasattr(x, '_mpmath_'):
|
||||
return ctx.convert(x._mpmath_(prec, rounding))
|
||||
return ctx._convert_fallback(x, strings)
|
||||
|
||||
def isnan(ctx, x):
|
||||
"""
|
||||
For an ``mpf`` *x*, determines whether *x* is not-a-number (nan)::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> isnan(nan), isnan(3)
|
||||
(True, False)
|
||||
"""
|
||||
if not hasattr(x, '_mpf_'):
|
||||
return False
|
||||
return x._mpf_ == fnan
|
||||
|
||||
def isinf(ctx, x):
|
||||
"""
|
||||
For an ``mpf`` *x*, determines whether *x* is infinite::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> isinf(inf), isinf(-inf), isinf(3)
|
||||
(True, True, False)
|
||||
"""
|
||||
if not hasattr(x, '_mpf_'):
|
||||
return False
|
||||
return x._mpf_ in (finf, fninf)
|
||||
|
||||
def isint(ctx, x):
|
||||
"""
|
||||
For an ``mpf`` *x*, or any type that can be converted
|
||||
to ``mpf``, determines whether *x* is exactly
|
||||
integer-valued::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> isint(3), isint(mpf(3)), isint(3.2)
|
||||
(True, True, False)
|
||||
"""
|
||||
if isinstance(x, int_types):
|
||||
return True
|
||||
try:
|
||||
x = ctx.convert(x)
|
||||
except:
|
||||
return False
|
||||
if hasattr(x, '_mpf_'):
|
||||
if ctx.isnan(x) or ctx.isinf(x):
|
||||
return False
|
||||
return x == int(x)
|
||||
if isinstance(x, ctx.mpq):
|
||||
p, q = x
|
||||
return not (p % q)
|
||||
return False
|
||||
|
||||
def fsum(ctx, terms, absolute=False, squared=False):
|
||||
"""
|
||||
Calculates a sum containing a finite number of terms (for infinite
|
||||
series, see :func:`nsum`). The terms will be converted to
|
||||
mpmath numbers. For len(terms) > 2, this function is generally
|
||||
faster and produces more accurate results than the builtin
|
||||
Python function :func:`sum`.
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> fsum([1, 2, 0.5, 7])
|
||||
mpf('10.5')
|
||||
|
||||
With squared=True each term is squared, and with absolute=True
|
||||
the absolute value of each term is used.
|
||||
"""
|
||||
prec, rnd = ctx._prec_rounding
|
||||
real = []
|
||||
imag = []
|
||||
other = 0
|
||||
for term in terms:
|
||||
reval = imval = 0
|
||||
if hasattr(term, "_mpf_"):
|
||||
reval = term._mpf_
|
||||
elif hasattr(term, "_mpc_"):
|
||||
reval, imval = term._mpc_
|
||||
else:
|
||||
term = ctx.convert(term)
|
||||
if hasattr(term, "_mpf_"):
|
||||
reval = term._mpf_
|
||||
elif hasattr(term, "_mpc_"):
|
||||
reval, imval = term._mpc_
|
||||
else:
|
||||
if absolute: term = ctx.absmax(term)
|
||||
if squared: term = term**2
|
||||
other += term
|
||||
continue
|
||||
if imval:
|
||||
if squared:
|
||||
if absolute:
|
||||
real.append(mpf_mul(reval,reval))
|
||||
real.append(mpf_mul(imval,imval))
|
||||
else:
|
||||
reval, imval = mpc_pow_int((reval,imval),2,prec+10)
|
||||
real.append(reval)
|
||||
imag.append(imval)
|
||||
elif absolute:
|
||||
real.append(mpc_abs((reval,imval), prec))
|
||||
else:
|
||||
real.append(reval)
|
||||
imag.append(imval)
|
||||
else:
|
||||
if squared:
|
||||
reval = mpf_mul(reval, reval)
|
||||
elif absolute:
|
||||
reval = mpf_abs(reval)
|
||||
real.append(reval)
|
||||
s = mpf_sum(real, prec, rnd, absolute)
|
||||
if imag:
|
||||
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
|
||||
else:
|
||||
s = ctx.make_mpf(s)
|
||||
if other is 0:
|
||||
return s
|
||||
else:
|
||||
return s + other
|
||||
|
||||
def fdot(ctx, A, B=None):
|
||||
r"""
|
||||
Computes the dot product of the iterables `A` and `B`,
|
||||
|
||||
.. math ::
|
||||
|
||||
\sum_{k=0} A_k B_k.
|
||||
|
||||
Alternatively, :func:`fdot` accepts a single iterable of pairs.
|
||||
In other words, ``fdot(A,B)`` and ``fdot(zip(A,B))`` are equivalent.
|
||||
|
||||
The elements are automatically converted to mpmath numbers.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> A = [2, 1.5, 3]
|
||||
>>> B = [1, -1, 2]
|
||||
>>> fdot(A, B)
|
||||
mpf('6.5')
|
||||
>>> zip(A, B)
|
||||
[(2, 1), (1.5, -1), (3, 2)]
|
||||
>>> fdot(_)
|
||||
mpf('6.5')
|
||||
|
||||
"""
|
||||
if B:
|
||||
A = zip(A, B)
|
||||
prec, rnd = ctx._prec_rounding
|
||||
real = []
|
||||
imag = []
|
||||
other = 0
|
||||
hasattr_ = hasattr
|
||||
types = (ctx.mpf, ctx.mpc)
|
||||
for a, b in A:
|
||||
if type(a) not in types: a = ctx.convert(a)
|
||||
if type(b) not in types: b = ctx.convert(b)
|
||||
a_real = hasattr_(a, "_mpf_")
|
||||
b_real = hasattr_(b, "_mpf_")
|
||||
if a_real and b_real:
|
||||
real.append(mpf_mul(a._mpf_, b._mpf_))
|
||||
continue
|
||||
a_complex = hasattr_(a, "_mpc_")
|
||||
b_complex = hasattr_(b, "_mpc_")
|
||||
if a_real and b_complex:
|
||||
aval = a._mpf_
|
||||
bre, bim = b._mpc_
|
||||
real.append(mpf_mul(aval, bre))
|
||||
imag.append(mpf_mul(aval, bim))
|
||||
elif b_real and a_complex:
|
||||
are, aim = a._mpc_
|
||||
bval = b._mpf_
|
||||
real.append(mpf_mul(are, bval))
|
||||
imag.append(mpf_mul(aim, bval))
|
||||
elif a_complex and b_complex:
|
||||
re, im = mpc_mul(a._mpc_, b._mpc_, prec+20)
|
||||
real.append(re)
|
||||
imag.append(im)
|
||||
else:
|
||||
other += a*b
|
||||
s = mpf_sum(real, prec, rnd)
|
||||
if imag:
|
||||
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
|
||||
else:
|
||||
s = ctx.make_mpf(s)
|
||||
if other is 0:
|
||||
return s
|
||||
else:
|
||||
return s + other
|
||||
|
||||
def _wrap_libmp_function(ctx, mpf_f, mpc_f=None, mpi_f=None, doc="<no doc>"):
|
||||
"""
|
||||
Given a low-level mpf_ function, and optionally similar functions
|
||||
for mpc_ and mpi_, defines the function as a context method.
|
||||
|
||||
It is assumed that the return type is the same as that of
|
||||
the input; the exception is that propagation from mpf to mpc is possible
|
||||
by raising ComplexResult.
|
||||
|
||||
"""
|
||||
def f(x, **kwargs):
|
||||
if type(x) not in ctx.types:
|
||||
x = ctx.convert(x)
|
||||
prec, rounding = ctx._prec_rounding
|
||||
if kwargs:
|
||||
prec = kwargs.get('prec', prec)
|
||||
if 'dps' in kwargs:
|
||||
prec = dps_to_prec(kwargs['dps'])
|
||||
rounding = kwargs.get('rounding', rounding)
|
||||
if hasattr(x, '_mpf_'):
|
||||
try:
|
||||
return ctx.make_mpf(mpf_f(x._mpf_, prec, rounding))
|
||||
except ComplexResult:
|
||||
# Handle propagation to complex
|
||||
if ctx.trap_complex:
|
||||
raise
|
||||
return ctx.make_mpc(mpc_f((x._mpf_, fzero), prec, rounding))
|
||||
elif hasattr(x, '_mpc_'):
|
||||
return ctx.make_mpc(mpc_f(x._mpc_, prec, rounding))
|
||||
elif hasattr(x, '_mpi_'):
|
||||
if mpi_f:
|
||||
return ctx.make_mpi(mpi_f(x._mpi_, prec))
|
||||
raise NotImplementedError("%s of a %s" % (name, type(x)))
|
||||
name = mpf_f.__name__[4:]
|
||||
f.__doc__ = function_docs.__dict__.get(name, "Computes the %s of x" % doc)
|
||||
return f
|
||||
|
||||
# Called by SpecialFunctions.__init__()
|
||||
@classmethod
|
||||
def _wrap_specfun(cls, name, f, wrap):
|
||||
if wrap:
|
||||
def f_wrapped(ctx, *args, **kwargs):
|
||||
convert = ctx.convert
|
||||
args = [convert(a) for a in args]
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
retval = f(ctx, *args, **kwargs)
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return +retval
|
||||
else:
|
||||
f_wrapped = f
|
||||
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
|
||||
setattr(cls, name, f_wrapped)
|
||||
|
||||
def _convert_param(ctx, x):
|
||||
if hasattr(x, "_mpc_"):
|
||||
v, im = x._mpc_
|
||||
if im != fzero:
|
||||
return x, 'C'
|
||||
elif hasattr(x, "_mpf_"):
|
||||
v = x._mpf_
|
||||
else:
|
||||
if type(x) in int_types:
|
||||
return int(x), 'Z'
|
||||
p = None
|
||||
if isinstance(x, tuple):
|
||||
p, q = x
|
||||
elif isinstance(x, basestring) and '/' in x:
|
||||
p, q = x.split('/')
|
||||
p = int(p)
|
||||
q = int(q)
|
||||
if p is not None:
|
||||
if not p % q:
|
||||
return p // q, 'Z'
|
||||
return ctx.mpq((p,q)), 'Q'
|
||||
x = ctx.convert(x)
|
||||
if hasattr(x, "_mpc_"):
|
||||
v, im = x._mpc_
|
||||
if im != fzero:
|
||||
return x, 'C'
|
||||
elif hasattr(x, "_mpf_"):
|
||||
v = x._mpf_
|
||||
else:
|
||||
return x, 'U'
|
||||
sign, man, exp, bc = v
|
||||
if man:
|
||||
if exp >= -4:
|
||||
if sign:
|
||||
man = -man
|
||||
if exp >= 0:
|
||||
return int(man) << exp, 'Z'
|
||||
if exp >= -4:
|
||||
p, q = int(man), (1<<(-exp))
|
||||
return ctx.mpq((p,q)), 'Q'
|
||||
x = ctx.make_mpf(v)
|
||||
return x, 'R'
|
||||
elif not exp:
|
||||
return 0, 'Z'
|
||||
else:
|
||||
return x, 'U'
|
||||
|
||||
def _mpf_mag(ctx, x):
|
||||
sign, man, exp, bc = x
|
||||
if man:
|
||||
return exp+bc
|
||||
if x == fzero:
|
||||
return ctx.ninf
|
||||
if x == finf or x == fninf:
|
||||
return ctx.inf
|
||||
return ctx.nan
|
||||
|
||||
def mag(ctx, x):
|
||||
"""
|
||||
Quick logarithmic magnitude estimate of a number.
|
||||
Returns an integer or infinity `m` such that `|x| <= 2^m`.
|
||||
It is not guaranteed that `m` is an optimal bound,
|
||||
but it will never be off by more than 2 (and probably not
|
||||
more than 1).
|
||||
"""
|
||||
if hasattr(x, "_mpf_"):
|
||||
return ctx._mpf_mag(x._mpf_)
|
||||
elif hasattr(x, "_mpc_"):
|
||||
r, i = x._mpc_
|
||||
if r == fzero:
|
||||
return ctx._mpf_mag(i)
|
||||
if i == fzero:
|
||||
return ctx._mpf_mag(r)
|
||||
return 1+max(ctx._mpf_mag(r), ctx._mpf_mag(i))
|
||||
elif isinstance(x, int_types):
|
||||
if x:
|
||||
return bitcount(abs(x))
|
||||
return ctx.ninf
|
||||
elif isinstance(x, rational.mpq):
|
||||
p, q = x
|
||||
if p:
|
||||
return 1 + bitcount(abs(p)) - bitcount(abs(q))
|
||||
return ctx.ninf
|
||||
else:
|
||||
x = ctx.convert(x)
|
||||
if hasattr(x, "_mpf_") or hasattr(x, "_mpc_"):
|
||||
return ctx.mag(x)
|
||||
else:
|
||||
raise TypeError("requires an mpf/mpc")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,7 +0,0 @@
|
|||
import functions
|
||||
# Hack to update methods
|
||||
import factorials
|
||||
import hypergeometric
|
||||
import elliptic
|
||||
import zeta
|
||||
import rszeta
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,196 +0,0 @@
|
|||
from functions import defun, defun_wrapped
|
||||
|
||||
@defun
|
||||
def gammaprod(ctx, a, b, _infsign=False):
|
||||
a = [ctx.convert(x) for x in a]
|
||||
b = [ctx.convert(x) for x in b]
|
||||
poles_num = []
|
||||
poles_den = []
|
||||
regular_num = []
|
||||
regular_den = []
|
||||
for x in a: [regular_num, poles_num][ctx.isnpint(x)].append(x)
|
||||
for x in b: [regular_den, poles_den][ctx.isnpint(x)].append(x)
|
||||
# One more pole in numerator or denominator gives 0 or inf
|
||||
if len(poles_num) < len(poles_den): return ctx.zero
|
||||
if len(poles_num) > len(poles_den):
|
||||
# Get correct sign of infinity for x+h, h -> 0 from above
|
||||
# XXX: hack, this should be done properly
|
||||
if _infsign:
|
||||
a = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_num]
|
||||
b = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_den]
|
||||
return ctx.sign(ctx.gammaprod(a+regular_num,b+regular_den)) * ctx.inf
|
||||
else:
|
||||
return ctx.inf
|
||||
# All poles cancel
|
||||
# lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
|
||||
p = ctx.one
|
||||
orig = ctx.prec
|
||||
try:
|
||||
ctx.prec = orig + 15
|
||||
while poles_num:
|
||||
i = poles_num.pop()
|
||||
j = poles_den.pop()
|
||||
p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
|
||||
for x in regular_num: p *= ctx.gamma(x)
|
||||
for x in regular_den: p /= ctx.gamma(x)
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
return +p
|
||||
|
||||
@defun
|
||||
def beta(ctx, x, y):
|
||||
x = ctx.convert(x)
|
||||
y = ctx.convert(y)
|
||||
if ctx.isinf(y):
|
||||
x, y = y, x
|
||||
if ctx.isinf(x):
|
||||
if x == ctx.inf and not ctx._im(y):
|
||||
if y == ctx.ninf:
|
||||
return ctx.nan
|
||||
if y > 0:
|
||||
return ctx.zero
|
||||
if ctx.isint(y):
|
||||
return ctx.nan
|
||||
if y < 0:
|
||||
return ctx.sign(ctx.gamma(y)) * ctx.inf
|
||||
return ctx.nan
|
||||
return ctx.gammaprod([x, y], [x+y])
|
||||
|
||||
@defun
|
||||
def binomial(ctx, n, k):
|
||||
return ctx.gammaprod([n+1], [k+1, n-k+1])
|
||||
|
||||
@defun
|
||||
def rf(ctx, x, n):
|
||||
return ctx.gammaprod([x+n], [x])
|
||||
|
||||
@defun
|
||||
def ff(ctx, x, n):
|
||||
return ctx.gammaprod([x+1], [x-n+1])
|
||||
|
||||
@defun_wrapped
|
||||
def fac2(ctx, x):
|
||||
if ctx.isinf(x):
|
||||
if x == ctx.inf:
|
||||
return x
|
||||
return ctx.nan
|
||||
return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)
|
||||
|
||||
@defun_wrapped
|
||||
def barnesg(ctx, z):
|
||||
if ctx.isinf(z):
|
||||
if z == ctx.inf:
|
||||
return z
|
||||
return ctx.nan
|
||||
if ctx.isnan(z):
|
||||
return z
|
||||
if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
|
||||
return z*0
|
||||
# Account for size (would not be needed if computing log(G))
|
||||
if abs(z) > 5:
|
||||
ctx.dps += 2*ctx.log(abs(z),2)
|
||||
# Reflection formula
|
||||
if ctx.re(z) < -ctx.dps:
|
||||
w = 1-z
|
||||
pi2 = 2*ctx.pi
|
||||
u = ctx.expjpi(2*w)
|
||||
v = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) - \
|
||||
ctx.j*ctx.polylog(2, u)/pi2
|
||||
v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
|
||||
if ctx._is_real_type(z):
|
||||
v = ctx._re(v)
|
||||
return v
|
||||
# Estimate terms for asymptotic expansion
|
||||
# TODO: fixme, obviously
|
||||
N = ctx.dps // 2 + 5
|
||||
G = 1
|
||||
while abs(z) < N or ctx.re(z) < 1:
|
||||
G /= ctx.gamma(z)
|
||||
z += 1
|
||||
z -= 1
|
||||
s = ctx.mpf(1)/12
|
||||
s -= ctx.log(ctx.glaisher)
|
||||
s += z*ctx.log(2*ctx.pi)/2
|
||||
s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
|
||||
s -= 3*z**2/4
|
||||
z2k = z2 = z**2
|
||||
for k in xrange(1, N+1):
|
||||
t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
|
||||
if abs(t) < ctx.eps:
|
||||
#print k, N # check how many terms were needed
|
||||
break
|
||||
z2k *= z2
|
||||
s += t
|
||||
#if k == N:
|
||||
# print "warning: series for barnesg failed to converge", ctx.dps
|
||||
return G*ctx.exp(s)
|
||||
|
||||
@defun
|
||||
def superfac(ctx, z):
|
||||
return ctx.barnesg(z+2)
|
||||
|
||||
@defun_wrapped
|
||||
def hyperfac(ctx, z):
|
||||
# XXX: estimate needed extra bits accurately
|
||||
if z == ctx.inf:
|
||||
return z
|
||||
if abs(z) > 5:
|
||||
extra = 4*int(ctx.log(abs(z),2))
|
||||
else:
|
||||
extra = 0
|
||||
ctx.prec += extra
|
||||
if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
|
||||
n = int(ctx.re(z))
|
||||
h = ctx.hyperfac(-n-1)
|
||||
if ((n+1)//2) & 1:
|
||||
h = -h
|
||||
if ctx._is_complex_type(z):
|
||||
return h + 0j
|
||||
return h
|
||||
zp1 = z+1
|
||||
# Wrong branch cut
|
||||
#v = ctx.gamma(zp1)**z
|
||||
#ctx.prec -= extra
|
||||
#return v / ctx.barnesg(zp1)
|
||||
v = ctx.exp(z*ctx.loggamma(zp1))
|
||||
ctx.prec -= extra
|
||||
return v / ctx.barnesg(zp1)
|
||||
|
||||
@defun_wrapped
|
||||
def loggamma(ctx, z):
|
||||
a = ctx._re(z)
|
||||
b = ctx._im(z)
|
||||
if not b and a > 0:
|
||||
return ctx.ln(ctx.gamma(z))
|
||||
u = ctx.arg(z)
|
||||
w = ctx.ln(ctx.gamma(z))
|
||||
if b:
|
||||
gi = -b - u/2 + a*u + b*ctx.ln(abs(z))
|
||||
n = ctx.floor((gi-ctx._im(w))/(2*ctx.pi)+0.5) * (2*ctx.pi)
|
||||
return w + n*ctx.j
|
||||
elif a < 0:
|
||||
n = int(ctx.floor(a))
|
||||
w += (n-(n%2))*ctx.pi*ctx.j
|
||||
return w
|
||||
|
||||
'''
|
||||
@defun
|
||||
def psi0(ctx, z):
|
||||
"""Shortcut for psi(0,z) (the digamma function)"""
|
||||
return ctx.psi(0, z)
|
||||
|
||||
@defun
|
||||
def psi1(ctx, z):
|
||||
"""Shortcut for psi(1,z) (the trigamma function)"""
|
||||
return ctx.psi(1, z)
|
||||
|
||||
@defun
|
||||
def psi2(ctx, z):
|
||||
"""Shortcut for psi(2,z) (the tetragamma function)"""
|
||||
return ctx.psi(2, z)
|
||||
|
||||
@defun
|
||||
def psi3(ctx, z):
|
||||
"""Shortcut for psi(3,z) (the pentagamma function)"""
|
||||
return ctx.psi(3, z)
|
||||
'''
|
||||
|
|
@ -1,435 +0,0 @@
|
|||
class SpecialFunctions(object):
|
||||
"""
|
||||
This class implements special functions using high-level code.
|
||||
|
||||
Elementary and some other functions (e.g. gamma function, basecase
|
||||
hypergeometric series) are assumed to be predefined by the context as
|
||||
"builtins" or "low-level" functions.
|
||||
"""
|
||||
defined_functions = {}
|
||||
|
||||
# The series for the Jacobi theta functions converge for |q| < 1;
|
||||
# in the current implementation they throw a ValueError for
|
||||
# abs(q) > THETA_Q_LIM
|
||||
THETA_Q_LIM = 1 - 10**-7
|
||||
|
||||
def __init__(self):
|
||||
cls = self.__class__
|
||||
for name in cls.defined_functions:
|
||||
f, wrap = cls.defined_functions[name]
|
||||
cls._wrap_specfun(name, f, wrap)
|
||||
|
||||
self.mpq_1 = self._mpq((1,1))
|
||||
self.mpq_0 = self._mpq((0,1))
|
||||
self.mpq_1_2 = self._mpq((1,2))
|
||||
self.mpq_3_2 = self._mpq((3,2))
|
||||
self.mpq_1_4 = self._mpq((1,4))
|
||||
self.mpq_1_16 = self._mpq((1,16))
|
||||
self.mpq_3_16 = self._mpq((3,16))
|
||||
self.mpq_5_2 = self._mpq((5,2))
|
||||
self.mpq_3_4 = self._mpq((3,4))
|
||||
self.mpq_7_4 = self._mpq((7,4))
|
||||
self.mpq_5_4 = self._mpq((5,4))
|
||||
|
||||
self._aliases.update({
|
||||
'phase' : 'arg',
|
||||
'conjugate' : 'conj',
|
||||
'nthroot' : 'root',
|
||||
'polygamma' : 'psi',
|
||||
'hurwitz' : 'zeta',
|
||||
#'digamma' : 'psi0',
|
||||
#'trigamma' : 'psi1',
|
||||
#'tetragamma' : 'psi2',
|
||||
#'pentagamma' : 'psi3',
|
||||
'fibonacci' : 'fib',
|
||||
'factorial' : 'fac',
|
||||
})
|
||||
|
||||
# Default -- do nothing
|
||||
@classmethod
|
||||
def _wrap_specfun(cls, name, f, wrap):
|
||||
setattr(cls, name, f)
|
||||
|
||||
# Optional fast versions of common functions in common cases.
|
||||
# If not overridden, default (generic hypergeometric series)
|
||||
# implementations will be used
|
||||
def _besselj(ctx, n, z): raise NotImplementedError
|
||||
def _erf(ctx, z): raise NotImplementedError
|
||||
def _erfc(ctx, z): raise NotImplementedError
|
||||
def _gamma_upper_int(ctx, z, a): raise NotImplementedError
|
||||
def _expint_int(ctx, n, z): raise NotImplementedError
|
||||
def _zeta(ctx, s): raise NotImplementedError
|
||||
def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
|
||||
def _ei(ctx, z): raise NotImplementedError
|
||||
def _e1(ctx, z): raise NotImplementedError
|
||||
def _ci(ctx, z): raise NotImplementedError
|
||||
def _si(ctx, z): raise NotImplementedError
|
||||
def _altzeta(ctx, s): raise NotImplementedError
|
||||
|
||||
def defun_wrapped(f):
|
||||
SpecialFunctions.defined_functions[f.__name__] = f, True
|
||||
|
||||
def defun(f):
|
||||
SpecialFunctions.defined_functions[f.__name__] = f, False
|
||||
|
||||
def defun_static(f):
|
||||
setattr(SpecialFunctions, f.__name__, f)
|
||||
|
||||
@defun_wrapped
|
||||
def cot(ctx, z): return ctx.one / ctx.tan(z)
|
||||
|
||||
@defun_wrapped
|
||||
def sec(ctx, z): return ctx.one / ctx.cos(z)
|
||||
|
||||
@defun_wrapped
|
||||
def csc(ctx, z): return ctx.one / ctx.sin(z)
|
||||
|
||||
@defun_wrapped
|
||||
def coth(ctx, z): return ctx.one / ctx.tanh(z)
|
||||
|
||||
@defun_wrapped
|
||||
def sech(ctx, z): return ctx.one / ctx.cosh(z)
|
||||
|
||||
@defun_wrapped
|
||||
def csch(ctx, z): return ctx.one / ctx.sinh(z)
|
||||
|
||||
@defun_wrapped
|
||||
def acot(ctx, z): return ctx.atan(ctx.one / z)
|
||||
|
||||
@defun_wrapped
|
||||
def asec(ctx, z): return ctx.acos(ctx.one / z)
|
||||
|
||||
@defun_wrapped
|
||||
def acsc(ctx, z): return ctx.asin(ctx.one / z)
|
||||
|
||||
@defun_wrapped
|
||||
def acoth(ctx, z): return ctx.atanh(ctx.one / z)
|
||||
|
||||
@defun_wrapped
|
||||
def asech(ctx, z): return ctx.acosh(ctx.one / z)
|
||||
|
||||
@defun_wrapped
|
||||
def acsch(ctx, z): return ctx.asinh(ctx.one / z)
|
||||
|
||||
@defun
|
||||
def sign(ctx, x):
|
||||
x = ctx.convert(x)
|
||||
if not x or ctx.isnan(x):
|
||||
return x
|
||||
if ctx._is_real_type(x):
|
||||
return ctx.mpf(cmp(x, 0))
|
||||
return x / abs(x)
|
||||
|
||||
@defun
|
||||
def agm(ctx, a, b=1):
|
||||
if b == 1:
|
||||
return ctx.agm1(a)
|
||||
a = ctx.convert(a)
|
||||
b = ctx.convert(b)
|
||||
return ctx._agm(a, b)
|
||||
|
||||
@defun_wrapped
|
||||
def sinc(ctx, x):
|
||||
if ctx.isinf(x):
|
||||
return 1/x
|
||||
if not x:
|
||||
return x+1
|
||||
return ctx.sin(x)/x
|
||||
|
||||
@defun_wrapped
|
||||
def sincpi(ctx, x):
|
||||
if ctx.isinf(x):
|
||||
return 1/x
|
||||
if not x:
|
||||
return x+1
|
||||
return ctx.sinpi(x)/(ctx.pi*x)
|
||||
|
||||
# TODO: tests; improve implementation
|
||||
@defun_wrapped
|
||||
def expm1(ctx, x):
|
||||
if not x:
|
||||
return ctx.zero
|
||||
# exp(x) - 1 ~ x
|
||||
if ctx.mag(x) < -ctx.prec:
|
||||
return x + 0.5*x**2
|
||||
# TODO: accurately eval the smaller of the real/imag parts
|
||||
return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)
|
||||
|
||||
@defun_wrapped
|
||||
def powm1(ctx, x, y):
|
||||
mag = ctx.mag
|
||||
one = ctx.one
|
||||
w = x**y - one
|
||||
M = mag(w)
|
||||
# Only moderate cancellation
|
||||
if M > -8:
|
||||
return w
|
||||
# Check for the only possible exact cases
|
||||
if not w:
|
||||
if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
|
||||
return w
|
||||
x1 = x - one
|
||||
magy = mag(y)
|
||||
lnx = ctx.ln(x)
|
||||
# Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
|
||||
if magy + mag(lnx) < -ctx.prec:
|
||||
return lnx*y + (lnx*y)**2/2
|
||||
# TODO: accurately eval the smaller of the real/imag part
|
||||
return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
|
||||
|
||||
@defun
|
||||
def _rootof1(ctx, k, n):
|
||||
k = int(k)
|
||||
n = int(n)
|
||||
k %= n
|
||||
if not k:
|
||||
return ctx.one
|
||||
elif 2*k == n:
|
||||
return -ctx.one
|
||||
elif 4*k == n:
|
||||
return ctx.j
|
||||
elif 4*k == 3*n:
|
||||
return -ctx.j
|
||||
return ctx.expjpi(2*ctx.mpf(k)/n)
|
||||
|
||||
@defun
|
||||
def root(ctx, x, n, k=0):
|
||||
n = int(n)
|
||||
x = ctx.convert(x)
|
||||
if k:
|
||||
# Special case: there is an exact real root
|
||||
if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
|
||||
return -ctx.root(-x, n)
|
||||
# Multiply by root of unity
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return +v
|
||||
return ctx._nthroot(x, n)
|
||||
|
||||
@defun
|
||||
def unitroots(ctx, n, primitive=False):
|
||||
gcd = ctx._gcd
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
if primitive:
|
||||
v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
|
||||
else:
|
||||
# TODO: this can be done *much* faster
|
||||
v = [ctx._rootof1(k,n) for k in range(n)]
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return [+x for x in v]
|
||||
|
||||
@defun
|
||||
def arg(ctx, x):
|
||||
x = ctx.convert(x)
|
||||
re = ctx._re(x)
|
||||
im = ctx._im(x)
|
||||
return ctx.atan2(im, re)
|
||||
|
||||
@defun
|
||||
def fabs(ctx, x):
|
||||
return abs(ctx.convert(x))
|
||||
|
||||
@defun
|
||||
def re(ctx, x):
|
||||
x = ctx.convert(x)
|
||||
if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
|
||||
return x.real
|
||||
return x
|
||||
|
||||
@defun
|
||||
def im(ctx, x):
|
||||
x = ctx.convert(x)
|
||||
if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
|
||||
return x.imag
|
||||
return ctx.zero
|
||||
|
||||
@defun
|
||||
def conj(ctx, x):
|
||||
return ctx.convert(x).conjugate()
|
||||
|
||||
@defun
|
||||
def polar(ctx, z):
|
||||
return (ctx.fabs(z), ctx.arg(z))
|
||||
|
||||
@defun_wrapped
|
||||
def rect(ctx, r, phi):
|
||||
return r * ctx.mpc(*ctx.cos_sin(phi))
|
||||
|
||||
@defun
|
||||
def log(ctx, x, b=None):
|
||||
if b is None:
|
||||
return ctx.ln(x)
|
||||
wp = ctx.prec + 20
|
||||
return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
|
||||
|
||||
@defun
|
||||
def log10(ctx, x):
|
||||
return ctx.log(x, 10)
|
||||
|
||||
@defun
|
||||
def modf(ctx, x, y):
|
||||
return ctx.convert(x) % ctx.convert(y)
|
||||
|
||||
@defun
|
||||
def degrees(ctx, x):
|
||||
return x / ctx.degree
|
||||
|
||||
@defun
|
||||
def radians(ctx, x):
|
||||
return x * ctx.degree
|
||||
|
||||
@defun_wrapped
|
||||
def lambertw(ctx, z, k=0):
|
||||
k = int(k)
|
||||
if ctx.isnan(z):
|
||||
return z
|
||||
ctx.prec += 20
|
||||
mag = ctx.mag(z)
|
||||
# Start from fp approximation
|
||||
if ctx is ctx._mp and abs(mag) < 900 and abs(k) < 10000 and \
|
||||
abs(z+0.36787944117144) > 0.01:
|
||||
w = ctx._fp.lambertw(z, k)
|
||||
else:
|
||||
absz = abs(z)
|
||||
# We must be extremely careful near the singularities at -1/e and 0
|
||||
u = ctx.exp(-1)
|
||||
if absz <= u:
|
||||
if not z:
|
||||
# w(0,0) = 0; for all other branches we hit the pole
|
||||
if not k:
|
||||
return z
|
||||
return ctx.ninf
|
||||
if not k:
|
||||
w = z
|
||||
# For small real z < 0, the -1 branch aves roughly like log(-z)
|
||||
elif k == -1 and not ctx.im(z) and ctx.re(z) < 0:
|
||||
w = ctx.ln(-z)
|
||||
# Use a simple asymptotic approximation.
|
||||
else:
|
||||
w = ctx.ln(z)
|
||||
# The branches are roughly logarithmic. This approximation
|
||||
# gets better for large |k|; need to check that this always
|
||||
# works for k ~= -1, 0, 1.
|
||||
if k: w += k * 2*ctx.pi*ctx.j
|
||||
elif k == 0 and ctx.im(z) and absz <= 0.7:
|
||||
# Both the W(z) ~= z and W(z) ~= ln(z) approximations break
|
||||
# down around z ~= -0.5 (converging to the wrong branch), so patch
|
||||
# with a constant approximation (adjusted for sign)
|
||||
if abs(z+0.5) < 0.1:
|
||||
if ctx.im(z) > 0:
|
||||
w = ctx.mpc(0.7+0.7j)
|
||||
else:
|
||||
w = ctx.mpc(0.7-0.7j)
|
||||
else:
|
||||
w = z
|
||||
else:
|
||||
if z == ctx.inf:
|
||||
if k == 0:
|
||||
return z
|
||||
else:
|
||||
return z + 2*k*ctx.pi*ctx.j
|
||||
if z == ctx.ninf:
|
||||
return (-z) + (2*k+1)*ctx.pi*ctx.j
|
||||
# Simple asymptotic approximation as above
|
||||
w = ctx.ln(z)
|
||||
if k:
|
||||
w += k * 2*ctx.pi*ctx.j
|
||||
# Use Halley iteration to solve w*exp(w) = z
|
||||
two = ctx.mpf(2)
|
||||
weps = ctx.ldexp(ctx.eps, 15)
|
||||
for i in xrange(100):
|
||||
ew = ctx.exp(w)
|
||||
wew = w*ew
|
||||
wewz = wew-z
|
||||
wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
|
||||
if abs(wn-w) < weps*abs(wn):
|
||||
return wn
|
||||
else:
|
||||
w = wn
|
||||
ctx.warn("Lambert W iteration failed to converge for %s" % z)
|
||||
return wn
|
||||
|
||||
@defun_wrapped
|
||||
def bell(ctx, n, x=1):
|
||||
x = ctx.convert(x)
|
||||
if not n:
|
||||
if ctx.isnan(x):
|
||||
return x
|
||||
return type(x)(1)
|
||||
if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
|
||||
return x**n
|
||||
if n == 1: return x
|
||||
if n == 2: return x*(x+1)
|
||||
if x == 0: return ctx.sincpi(n)
|
||||
return _polyexp(ctx, n, x, True) / ctx.exp(x)
|
||||
|
||||
def _polyexp(ctx, n, x, extra=False):
|
||||
def _terms():
|
||||
if extra:
|
||||
yield ctx.sincpi(n)
|
||||
t = x
|
||||
k = 1
|
||||
while 1:
|
||||
yield k**n * t
|
||||
k += 1
|
||||
t = t*x/k
|
||||
return ctx.sum_accurately(_terms, check_step=4)
|
||||
|
||||
@defun_wrapped
|
||||
def polyexp(ctx, s, z):
|
||||
if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
|
||||
return z**s
|
||||
if z == 0: return z*s
|
||||
if s == 0: return ctx.expm1(z)
|
||||
if s == 1: return ctx.exp(z)*z
|
||||
if s == 2: return ctx.exp(z)*z*(z+1)
|
||||
return _polyexp(ctx, s, z)
|
||||
|
||||
@defun_wrapped
|
||||
def cyclotomic(ctx, n, z):
|
||||
n = int(n)
|
||||
assert n >= 0
|
||||
p = ctx.one
|
||||
if n == 0:
|
||||
return p
|
||||
if n == 1:
|
||||
return z - p
|
||||
if n == 2:
|
||||
return z + p
|
||||
# Use divisor product representation. Unfortunately, this sometimes
|
||||
# includes singularities for roots of unity, which we have to cancel out.
|
||||
# Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
|
||||
a_prod = 1
|
||||
b_prod = 1
|
||||
num_zeros = 0
|
||||
num_poles = 0
|
||||
for d in range(1,n+1):
|
||||
if not n % d:
|
||||
w = ctx.moebius(n//d)
|
||||
# Use powm1 because it is important that we get 0 only
|
||||
# if it really is exactly 0
|
||||
b = -ctx.powm1(z, d)
|
||||
if b:
|
||||
p *= b**w
|
||||
else:
|
||||
if w == 1:
|
||||
a_prod *= d
|
||||
num_zeros += 1
|
||||
elif w == -1:
|
||||
b_prod *= d
|
||||
num_poles += 1
|
||||
#print n, num_zeros, num_poles
|
||||
if num_zeros:
|
||||
if num_zeros > num_poles:
|
||||
p *= 0
|
||||
else:
|
||||
p *= a_prod
|
||||
p /= b_prod
|
||||
return p
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,693 +0,0 @@
|
|||
from functions import defun, defun_wrapped, defun_static
|
||||
|
||||
@defun
|
||||
def stieltjes(ctx, n, a=1):
|
||||
n = ctx.convert(n)
|
||||
a = ctx.convert(a)
|
||||
if n < 0:
|
||||
return ctx.bad_domain("Stieltjes constants defined for n >= 0")
|
||||
if hasattr(ctx, "stieltjes_cache"):
|
||||
stieltjes_cache = ctx.stieltjes_cache
|
||||
else:
|
||||
stieltjes_cache = ctx.stieltjes_cache = {}
|
||||
if a == 1:
|
||||
if n == 0:
|
||||
return +ctx.euler
|
||||
if n in stieltjes_cache:
|
||||
prec, s = stieltjes_cache[n]
|
||||
if prec >= ctx.prec:
|
||||
return +s
|
||||
mag = 1
|
||||
def f(x):
|
||||
xa = x/a
|
||||
v = (xa-ctx.j)*ctx.ln(a-ctx.j*x)**n/(1+xa**2)/(ctx.exp(2*ctx.pi*x)-1)
|
||||
return ctx._re(v) / mag
|
||||
orig = ctx.prec
|
||||
try:
|
||||
# Normalize integrand by approx. magnitude to
|
||||
# speed up quadrature (which uses absolute error)
|
||||
if n > 50:
|
||||
ctx.prec = 20
|
||||
mag = ctx.quad(f, [0,ctx.inf], maxdegree=3)
|
||||
ctx.prec = orig + 10 + int(n**0.5)
|
||||
s = ctx.quad(f, [0,ctx.inf], maxdegree=20)
|
||||
v = ctx.ln(a)**n/(2*a) - ctx.ln(a)**(n+1)/(n+1) + 2*s/a*mag
|
||||
finally:
|
||||
ctx.prec = orig
|
||||
if a == 1 and ctx.isint(n):
|
||||
stieltjes_cache[n] = (ctx.prec, v)
|
||||
return +v
|
||||
|
||||
@defun_wrapped
|
||||
def siegeltheta(ctx, t):
|
||||
if ctx._im(t):
|
||||
# XXX: cancellation occurs
|
||||
a = ctx.loggamma(0.25+0.5j*t)
|
||||
b = ctx.loggamma(0.25-0.5j*t)
|
||||
return -ctx.ln(ctx.pi)/2*t - 0.5j*(a-b)
|
||||
else:
|
||||
if ctx.isinf(t):
|
||||
return t
|
||||
return ctx._im(ctx.loggamma(0.25+0.5j*t)) - ctx.ln(ctx.pi)/2*t
|
||||
|
||||
@defun_wrapped
|
||||
def grampoint(ctx, n):
|
||||
# asymptotic expansion, from
|
||||
# http://mathworld.wolfram.com/GramPoint.html
|
||||
g = 2*ctx.pi*ctx.exp(1+ctx.lambertw((8*n+1)/(8*ctx.e)))
|
||||
return ctx.findroot(lambda t: ctx.siegeltheta(t)-ctx.pi*n, g)
|
||||
|
||||
@defun_wrapped
|
||||
def siegelz(ctx, t):
|
||||
v = ctx.expj(ctx.siegeltheta(t))*ctx.zeta(0.5+ctx.j*t)
|
||||
if ctx._is_real_type(t):
|
||||
return ctx._re(v)
|
||||
return v
|
||||
|
||||
_zeta_zeros = [
|
||||
14.134725142,21.022039639,25.010857580,30.424876126,32.935061588,
|
||||
37.586178159,40.918719012,43.327073281,48.005150881,49.773832478,
|
||||
52.970321478,56.446247697,59.347044003,60.831778525,65.112544048,
|
||||
67.079810529,69.546401711,72.067157674,75.704690699,77.144840069,
|
||||
79.337375020,82.910380854,84.735492981,87.425274613,88.809111208,
|
||||
92.491899271,94.651344041,95.870634228,98.831194218,101.317851006,
|
||||
103.725538040,105.446623052,107.168611184,111.029535543,111.874659177,
|
||||
114.320220915,116.226680321,118.790782866,121.370125002,122.946829294,
|
||||
124.256818554,127.516683880,129.578704200,131.087688531,133.497737203,
|
||||
134.756509753,138.116042055,139.736208952,141.123707404,143.111845808,
|
||||
146.000982487,147.422765343,150.053520421,150.925257612,153.024693811,
|
||||
156.112909294,157.597591818,158.849988171,161.188964138,163.030709687,
|
||||
165.537069188,167.184439978,169.094515416,169.911976479,173.411536520,
|
||||
174.754191523,176.441434298,178.377407776,179.916484020,182.207078484,
|
||||
184.874467848,185.598783678,187.228922584,189.416158656,192.026656361,
|
||||
193.079726604,195.265396680,196.876481841,198.015309676,201.264751944,
|
||||
202.493594514,204.189671803,205.394697202,207.906258888,209.576509717,
|
||||
211.690862595,213.347919360,214.547044783,216.169538508,219.067596349,
|
||||
220.714918839,221.430705555,224.007000255,224.983324670,227.421444280,
|
||||
229.337413306,231.250188700,231.987235253,233.693404179,236.524229666,
|
||||
]
|
||||
|
||||
def _load_zeta_zeros(url):
|
||||
import urllib
|
||||
d = urllib.urlopen(url)
|
||||
L = [float(x) for x in d.readlines()]
|
||||
# Sanity check
|
||||
assert round(L[0]) == 14
|
||||
_zeta_zeros[:] = L
|
||||
|
||||
@defun
|
||||
def zetazero(ctx, n, url='http://www.dtc.umn.edu/~odlyzko/zeta_tables/zeros1'):
|
||||
n = int(n)
|
||||
if n < 0:
|
||||
return ctx.zetazero(-n).conjugate()
|
||||
if n == 0:
|
||||
raise ValueError("n must be nonzero")
|
||||
if n > len(_zeta_zeros) and n <= 100000:
|
||||
_load_zeta_zeros(url)
|
||||
if n > len(_zeta_zeros):
|
||||
raise NotImplementedError("n too large for zetazeros")
|
||||
return ctx.mpc(0.5, ctx.findroot(ctx.siegelz, _zeta_zeros[n-1]))
|
||||
|
||||
@defun_wrapped
|
||||
def riemannr(ctx, x):
|
||||
if x == 0:
|
||||
return ctx.zero
|
||||
# Check if a simple asymptotic estimate is accurate enough
|
||||
if abs(x) > 1000:
|
||||
a = ctx.li(x)
|
||||
b = 0.5*ctx.li(ctx.sqrt(x))
|
||||
if abs(b) < abs(a)*ctx.eps:
|
||||
return a
|
||||
if abs(x) < 0.01:
|
||||
# XXX
|
||||
ctx.prec += int(-ctx.log(abs(x),2))
|
||||
# Sum Gram's series
|
||||
s = t = ctx.one
|
||||
u = ctx.ln(x)
|
||||
k = 1
|
||||
while abs(t) > abs(s)*ctx.eps:
|
||||
t = t * u / k
|
||||
s += t / (k * ctx._zeta_int(k+1))
|
||||
k += 1
|
||||
return s
|
||||
|
||||
@defun_static
|
||||
def primepi(ctx, x):
|
||||
x = int(x)
|
||||
if x < 2:
|
||||
return 0
|
||||
return len(ctx.list_primes(x))
|
||||
|
||||
@defun_wrapped
|
||||
def primepi2(ctx, x):
|
||||
x = int(x)
|
||||
if x < 2:
|
||||
return ctx.mpi(0,0)
|
||||
if x < 2657:
|
||||
return ctx.mpi(ctx.primepi(x))
|
||||
mid = ctx.li(x)
|
||||
# Schoenfeld's estimate for x >= 2657, assuming RH
|
||||
err = ctx.sqrt(x,rounding='u')*ctx.ln(x,rounding='u')/8/ctx.pi(rounding='d')
|
||||
a = ctx.floor((ctx.mpi(mid)-err).a, rounding='d')
|
||||
b = ctx.ceil((ctx.mpi(mid)+err).b, rounding='u')
|
||||
return ctx.mpi(a, b)
|
||||
|
||||
@defun_wrapped
|
||||
def primezeta(ctx, s):
|
||||
if ctx.isnan(s):
|
||||
return s
|
||||
if ctx.re(s) <= 0:
|
||||
raise ValueError("prime zeta function defined only for re(s) > 0")
|
||||
if s == 1:
|
||||
return ctx.inf
|
||||
if s == 0.5:
|
||||
return ctx.mpc(ctx.ninf, ctx.pi)
|
||||
r = ctx.re(s)
|
||||
if r > ctx.prec:
|
||||
return 0.5**s
|
||||
else:
|
||||
wp = ctx.prec + int(r)
|
||||
def terms():
|
||||
orig = ctx.prec
|
||||
# zeta ~ 1+eps; need to set precision
|
||||
# to get logarithm accurately
|
||||
k = 0
|
||||
while 1:
|
||||
k += 1
|
||||
u = ctx.moebius(k)
|
||||
if not u:
|
||||
continue
|
||||
ctx.prec = wp
|
||||
t = u*ctx.ln(ctx.zeta(k*s))/k
|
||||
if not t:
|
||||
return
|
||||
#print ctx.prec, ctx.nstr(t)
|
||||
ctx.prec = orig
|
||||
yield t
|
||||
return ctx.sum_accurately(terms)
|
||||
|
||||
# TODO: for bernpoly and eulerpoly, ensure that all exact zeros are covered
|
||||
|
||||
@defun_wrapped
|
||||
def bernpoly(ctx, n, z):
|
||||
# Slow implementation:
|
||||
#return sum(ctx.binomial(n,k)*ctx.bernoulli(k)*z**(n-k) for k in xrange(0,n+1))
|
||||
n = int(n)
|
||||
if n < 0:
|
||||
raise ValueError("Bernoulli polynomials only defined for n >= 0")
|
||||
if z == 0 or (z == 1 and n > 1):
|
||||
return ctx.bernoulli(n)
|
||||
if z == 0.5:
|
||||
return (ctx.ldexp(1,1-n)-1)*ctx.bernoulli(n)
|
||||
if n <= 3:
|
||||
if n == 0: return z ** 0
|
||||
if n == 1: return z - 0.5
|
||||
if n == 2: return (6*z*(z-1)+1)/6
|
||||
if n == 3: return z*(z*(z-1.5)+0.5)
|
||||
if abs(z) == ctx.inf:
|
||||
return z ** n
|
||||
if z != z:
|
||||
return z
|
||||
if abs(z) > 2:
|
||||
def terms():
|
||||
t = ctx.one
|
||||
yield t
|
||||
r = ctx.one/z
|
||||
k = 1
|
||||
while k <= n:
|
||||
t = t*(n+1-k)/k*r
|
||||
if not (k > 2 and k & 1):
|
||||
yield t*ctx.bernoulli(k)
|
||||
k += 1
|
||||
return ctx.sum_accurately(terms) * z**n
|
||||
else:
|
||||
def terms():
|
||||
yield ctx.bernoulli(n)
|
||||
t = ctx.one
|
||||
k = 1
|
||||
while k <= n:
|
||||
t = t*(n+1-k)/k * z
|
||||
m = n-k
|
||||
if not (m > 2 and m & 1):
|
||||
yield t*ctx.bernoulli(m)
|
||||
k += 1
|
||||
return ctx.sum_accurately(terms)
|
||||
|
||||
@defun_wrapped
|
||||
def eulerpoly(ctx, n, z):
|
||||
n = int(n)
|
||||
if n < 0:
|
||||
raise ValueError("Euler polynomials only defined for n >= 0")
|
||||
if n <= 2:
|
||||
if n == 0: return z ** 0
|
||||
if n == 1: return z - 0.5
|
||||
if n == 2: return z*(z-1)
|
||||
if abs(z) == ctx.inf:
|
||||
return z**n
|
||||
if z != z:
|
||||
return z
|
||||
m = n+1
|
||||
if z == 0:
|
||||
return -2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
|
||||
if z == 1:
|
||||
return 2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
|
||||
if z == 0.5:
|
||||
if n % 2:
|
||||
return ctx.zero
|
||||
# Use exact code for Euler numbers
|
||||
if n < 100 or n*ctx.mag(0.46839865*n) < ctx.prec*0.25:
|
||||
return ctx.ldexp(ctx._eulernum(n), -n)
|
||||
# http://functions.wolfram.com/Polynomials/EulerE2/06/01/02/01/0002/
|
||||
def terms():
|
||||
t = ctx.one
|
||||
k = 0
|
||||
w = ctx.ldexp(1,n+2)
|
||||
while 1:
|
||||
v = n-k+1
|
||||
if not (v > 2 and v & 1):
|
||||
yield (2-w)*ctx.bernoulli(v)*t
|
||||
k += 1
|
||||
if k > n:
|
||||
break
|
||||
t = t*z*(n-k+2)/k
|
||||
w *= 0.5
|
||||
return ctx.sum_accurately(terms) / m
|
||||
|
||||
@defun
|
||||
def eulernum(ctx, n, exact=False):
|
||||
n = int(n)
|
||||
if exact:
|
||||
return int(ctx._eulernum(n))
|
||||
if n < 100:
|
||||
return ctx.mpf(ctx._eulernum(n))
|
||||
if n % 2:
|
||||
return ctx.zero
|
||||
return ctx.ldexp(ctx.eulerpoly(n,0.5), n)
|
||||
|
||||
# TODO: this should be implemented low-level
|
||||
def polylog_series(ctx, s, z):
|
||||
tol = +ctx.eps
|
||||
l = ctx.zero
|
||||
k = 1
|
||||
zk = z
|
||||
while 1:
|
||||
term = zk / k**s
|
||||
l += term
|
||||
if abs(term) < tol:
|
||||
break
|
||||
zk *= z
|
||||
k += 1
|
||||
return l
|
||||
|
||||
def polylog_continuation(ctx, n, z):
|
||||
if n < 0:
|
||||
return z*0
|
||||
twopij = 2j * ctx.pi
|
||||
a = -twopij**n/ctx.fac(n) * ctx.bernpoly(n, ctx.ln(z)/twopij)
|
||||
if ctx._is_real_type(z) and z < 0:
|
||||
a = ctx._re(a)
|
||||
if ctx._im(z) < 0 or (ctx._im(z) == 0 and ctx._re(z) >= 1):
|
||||
a -= twopij*ctx.ln(z)**(n-1)/ctx.fac(n-1)
|
||||
return a
|
||||
|
||||
def polylog_unitcircle(ctx, n, z):
|
||||
tol = +ctx.eps
|
||||
if n > 1:
|
||||
l = ctx.zero
|
||||
logz = ctx.ln(z)
|
||||
logmz = ctx.one
|
||||
m = 0
|
||||
while 1:
|
||||
if (n-m) != 1:
|
||||
term = ctx.zeta(n-m) * logmz / ctx.fac(m)
|
||||
if term and abs(term) < tol:
|
||||
break
|
||||
l += term
|
||||
logmz *= logz
|
||||
m += 1
|
||||
l += ctx.ln(z)**(n-1)/ctx.fac(n-1)*(ctx.harmonic(n-1)-ctx.ln(-ctx.ln(z)))
|
||||
elif n < 1: # else
|
||||
l = ctx.fac(-n)*(-ctx.ln(z))**(n-1)
|
||||
logz = ctx.ln(z)
|
||||
logkz = ctx.one
|
||||
k = 0
|
||||
while 1:
|
||||
b = ctx.bernoulli(k-n+1)
|
||||
if b:
|
||||
term = b*logkz/(ctx.fac(k)*(k-n+1))
|
||||
if abs(term) < tol:
|
||||
break
|
||||
l -= term
|
||||
logkz *= logz
|
||||
k += 1
|
||||
else:
|
||||
raise ValueError
|
||||
if ctx._is_real_type(z) and z < 0:
|
||||
l = ctx._re(l)
|
||||
return l
|
||||
|
||||
def polylog_general(ctx, s, z):
|
||||
v = ctx.zero
|
||||
u = ctx.ln(z)
|
||||
if not abs(u) < 5: # theoretically |u| < 2*pi
|
||||
raise NotImplementedError("polylog for arbitrary s and z")
|
||||
t = 1
|
||||
k = 0
|
||||
while 1:
|
||||
term = ctx.zeta(s-k) * t
|
||||
if abs(term) < ctx.eps:
|
||||
break
|
||||
v += term
|
||||
k += 1
|
||||
t *= u
|
||||
t /= k
|
||||
return ctx.gamma(1-s)*(-u)**(s-1) + v
|
||||
|
||||
@defun_wrapped
|
||||
def polylog(ctx, s, z):
|
||||
s = ctx.convert(s)
|
||||
z = ctx.convert(z)
|
||||
if z == 1:
|
||||
return ctx.zeta(s)
|
||||
if z == -1:
|
||||
return -ctx.altzeta(s)
|
||||
if s == 0:
|
||||
return z/(1-z)
|
||||
if s == 1:
|
||||
return -ctx.ln(1-z)
|
||||
if s == -1:
|
||||
return z/(1-z)**2
|
||||
if abs(z) <= 0.75 or (not ctx.isint(s) and abs(z) < 0.9):
|
||||
return polylog_series(ctx, s, z)
|
||||
if abs(z) >= 1.4 and ctx.isint(s):
|
||||
return (-1)**(s+1)*polylog_series(ctx, s, 1/z) + polylog_continuation(ctx, s, z)
|
||||
if ctx.isint(s):
|
||||
return polylog_unitcircle(ctx, int(s), z)
|
||||
return polylog_general(ctx, s, z)
|
||||
|
||||
#raise NotImplementedError("polylog for arbitrary s and z")
|
||||
# This could perhaps be used in some cases
|
||||
#from quadrature import quad
|
||||
#return quad(lambda t: t**(s-1)/(exp(t)/z-1),[0,inf])/gamma(s)
|
||||
|
||||
@defun_wrapped
|
||||
def clsin(ctx, s, z, pi=False):
|
||||
if ctx.isint(s) and s < 0 and int(s) % 2 == 1:
|
||||
return z*0
|
||||
if pi:
|
||||
a = ctx.expjpi(z)
|
||||
else:
|
||||
a = ctx.expj(z)
|
||||
if ctx._is_real_type(z) and ctx._is_real_type(s):
|
||||
return ctx.im(ctx.polylog(s,a))
|
||||
b = 1/a
|
||||
return (-0.5j)*(ctx.polylog(s,a) - ctx.polylog(s,b))
|
||||
|
||||
@defun_wrapped
|
||||
def clcos(ctx, s, z, pi=False):
|
||||
if ctx.isint(s) and s < 0 and int(s) % 2 == 0:
|
||||
return z*0
|
||||
if pi:
|
||||
a = ctx.expjpi(z)
|
||||
else:
|
||||
a = ctx.expj(z)
|
||||
if ctx._is_real_type(z) and ctx._is_real_type(s):
|
||||
return ctx.re(ctx.polylog(s,a))
|
||||
b = 1/a
|
||||
return 0.5*(ctx.polylog(s,a) + ctx.polylog(s,b))
|
||||
|
||||
@defun
|
||||
def altzeta(ctx, s, **kwargs):
|
||||
try:
|
||||
return ctx._altzeta(s, **kwargs)
|
||||
except NotImplementedError:
|
||||
return ctx._altzeta_generic(s)
|
||||
|
||||
@defun_wrapped
|
||||
def _altzeta_generic(ctx, s):
|
||||
if s == 1:
|
||||
return ctx.ln2 + 0*s
|
||||
return -ctx.powm1(2, 1-s) * ctx.zeta(s)
|
||||
|
||||
@defun
|
||||
def zeta(ctx, s, a=1, derivative=0, method=None, **kwargs):
|
||||
d = int(derivative)
|
||||
if a == 1 and not (d or method):
|
||||
try:
|
||||
return ctx._zeta(s, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
s = ctx.convert(s)
|
||||
prec = ctx.prec
|
||||
method = kwargs.get('method')
|
||||
verbose = kwargs.get('verbose')
|
||||
if a == 1 and method != 'euler-maclaurin':
|
||||
im = abs(ctx._im(s))
|
||||
re = abs(ctx._re(s))
|
||||
#if (im < prec or method == 'borwein') and not derivative:
|
||||
# try:
|
||||
# if verbose:
|
||||
# print "zeta: Attempting to use the Borwein algorithm"
|
||||
# return ctx._zeta(s, **kwargs)
|
||||
# except NotImplementedError:
|
||||
# if verbose:
|
||||
# print "zeta: Could not use the Borwein algorithm"
|
||||
# pass
|
||||
if abs(im) > 60*prec and 10*re < prec and derivative <= 4 or \
|
||||
method == 'riemann-siegel':
|
||||
try: # py2.4 compatible try block
|
||||
try:
|
||||
if verbose:
|
||||
print "zeta: Attempting to use the Riemann-Siegel algorithm"
|
||||
return ctx.rs_zeta(s, derivative, **kwargs)
|
||||
except NotImplementedError:
|
||||
if verbose:
|
||||
print "zeta: Could not use the Riemann-Siegel algorithm"
|
||||
pass
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
if s == 1:
|
||||
return ctx.inf
|
||||
abss = abs(s)
|
||||
if abss == ctx.inf:
|
||||
if ctx.re(s) == ctx.inf:
|
||||
if d == 0:
|
||||
return ctx.one
|
||||
return ctx.zero
|
||||
return s*0
|
||||
elif ctx.isnan(abss):
|
||||
return 1/s
|
||||
if ctx.re(s) > 2*ctx.prec and a == 1 and not derivative:
|
||||
return ctx.one + ctx.power(2, -s)
|
||||
if verbose:
|
||||
print "zeta: Using the Euler-Maclaurin algorithm"
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
v = ctx._hurwitz(s, a, d)
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return +v
|
||||
|
||||
@defun
|
||||
def _hurwitz(ctx, s, a=1, d=0):
|
||||
# We strongly want to special-case rational a
|
||||
a, atype = ctx._convert_param(a)
|
||||
prec = ctx.prec
|
||||
# TODO: implement reflection for derivatives
|
||||
res = ctx.re(s)
|
||||
negs = -s
|
||||
try:
|
||||
if res < 0 and not d:
|
||||
# Integer reflection formula
|
||||
if ctx.isnpint(s):
|
||||
n = int(res)
|
||||
if n <= 0:
|
||||
return ctx.bernpoly(1-n, a) / (n-1)
|
||||
t = 1-s
|
||||
# We now require a to be standardized
|
||||
v = 0
|
||||
shift = 0
|
||||
b = a
|
||||
while ctx.re(b) > 1:
|
||||
b -= 1
|
||||
v -= b**negs
|
||||
shift -= 1
|
||||
while ctx.re(b) <= 0:
|
||||
v += b**negs
|
||||
b += 1
|
||||
shift += 1
|
||||
# Rational reflection formula
|
||||
if atype == 'Q' or atype == 'Z':
|
||||
try:
|
||||
p, q = a
|
||||
except:
|
||||
assert a == int(a)
|
||||
p = int(a)
|
||||
q = 1
|
||||
p += shift*q
|
||||
assert 1 <= p <= q
|
||||
g = ctx.fsum(ctx.cospi(t/2-2*k*b)*ctx._hurwitz(t,(k,q)) \
|
||||
for k in range(1,q+1))
|
||||
g *= 2*ctx.gamma(t)/(2*ctx.pi*q)**t
|
||||
v += g
|
||||
return v
|
||||
# General reflection formula
|
||||
else:
|
||||
C1 = ctx.cospi(t/2)
|
||||
C2 = ctx.sinpi(t/2)
|
||||
# Clausen functions; could maybe use polylog directly
|
||||
if C1: C1 *= ctx.clcos(t, 2*a, pi=True)
|
||||
if C2: C2 *= ctx.clsin(t, 2*a, pi=True)
|
||||
v += 2*ctx.gamma(t)/(2*ctx.pi)**t*(C1+C2)
|
||||
return v
|
||||
except NotImplementedError:
|
||||
pass
|
||||
a = ctx.convert(a)
|
||||
tol = -prec
|
||||
# Estimate number of terms for Euler-Maclaurin summation; could be improved
|
||||
M1 = 0
|
||||
M2 = prec // 3
|
||||
N = M2
|
||||
lsum = 0
|
||||
# This speeds up the recurrence for derivatives
|
||||
if ctx.isint(s):
|
||||
s = int(ctx._re(s))
|
||||
s1 = s-1
|
||||
while 1:
|
||||
# Truncated L-series
|
||||
l = ctx._zetasum(s, M1+a, M2-M1-1, [d])[0][0]
|
||||
#if d:
|
||||
# l = ctx.fsum((-ctx.ln(n+a))**d * (n+a)**negs for n in range(M1,M2))
|
||||
#else:
|
||||
# l = ctx.fsum((n+a)**negs for n in range(M1,M2))
|
||||
lsum += l
|
||||
M2a = M2+a
|
||||
logM2a = ctx.ln(M2a)
|
||||
logM2ad = logM2a**d
|
||||
logs = [logM2ad]
|
||||
logr = 1/logM2a
|
||||
rM2a = 1/M2a
|
||||
M2as = rM2a**s
|
||||
if d:
|
||||
tailsum = ctx.gammainc(d+1, s1*logM2a) / s1**(d+1)
|
||||
else:
|
||||
tailsum = 1/((s1)*(M2a)**s1)
|
||||
tailsum += 0.5 * logM2ad * M2as
|
||||
U = [1]
|
||||
r = M2as
|
||||
fact = 2
|
||||
for j in range(1, N+1):
|
||||
# TODO: the following could perhaps be tidied a bit
|
||||
j2 = 2*j
|
||||
if j == 1:
|
||||
upds = [1]
|
||||
else:
|
||||
upds = [j2-2, j2-1]
|
||||
for m in upds:
|
||||
D = min(m,d+1)
|
||||
if m <= d:
|
||||
logs.append(logs[-1] * logr)
|
||||
Un = [0]*(D+1)
|
||||
for i in xrange(D): Un[i] = (1-m-s)*U[i]
|
||||
for i in xrange(1,D+1): Un[i] += (d-(i-1))*U[i-1]
|
||||
U = Un
|
||||
r *= rM2a
|
||||
t = ctx.fdot(U, logs) * r * ctx.bernoulli(j2)/(-fact)
|
||||
tailsum += t
|
||||
if ctx.mag(t) < tol:
|
||||
return lsum + (-1)**d * tailsum
|
||||
fact *= (j2+1)*(j2+2)
|
||||
M1, M2 = M2, M2*2
|
||||
|
||||
@defun
|
||||
def _zetasum(ctx, s, a, n, derivatives=[0], reflect=False):
|
||||
"""
|
||||
Returns [xd0,xd1,...,xdr], [yd0,yd1,...ydr] where
|
||||
|
||||
xdk = D^k ( 1/a^s + 1/(a+1)^s + ... + 1/(a+n)^s )
|
||||
ydk = D^k conj( 1/a^(1-s) + 1/(a+1)^(1-s) + ... + 1/(a+n)^(1-s) )
|
||||
|
||||
D^k = kth derivative with respect to s, k ranges over the given list of
|
||||
derivatives (which should consist of either a single element
|
||||
or a range 0,1,...r). If reflect=False, the ydks are not computed.
|
||||
"""
|
||||
try:
|
||||
return ctx._zetasum_fast(s, a, n, derivatives, reflect)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
negs = ctx.fneg(s, exact=True)
|
||||
have_derivatives = derivatives != [0]
|
||||
have_one_derivative = len(derivatives) == 1
|
||||
if not reflect:
|
||||
if not have_derivatives:
|
||||
return [ctx.fsum((a+k)**negs for k in xrange(n+1))], []
|
||||
if have_one_derivative:
|
||||
d = derivatives[0]
|
||||
x = ctx.fsum(ctx.ln(a+k)**d * (a+k)**negs for k in xrange(n+1))
|
||||
return [(-1)**d * x], []
|
||||
maxd = max(derivatives)
|
||||
if not have_one_derivative:
|
||||
derivatives = range(maxd+1)
|
||||
xs = [ctx.zero for d in derivatives]
|
||||
if reflect:
|
||||
ys = [ctx.zero for d in derivatives]
|
||||
else:
|
||||
ys = []
|
||||
for k in xrange(n+1):
|
||||
w = a + k
|
||||
xterm = w ** negs
|
||||
if reflect:
|
||||
yterm = ctx.conj(ctx.one / (w * xterm))
|
||||
if have_derivatives:
|
||||
logw = -ctx.ln(w)
|
||||
if have_one_derivative:
|
||||
logw = logw ** maxd
|
||||
xs[0] += xterm * logw
|
||||
if reflect:
|
||||
ys[0] += yterm * logw
|
||||
else:
|
||||
t = ctx.one
|
||||
for d in derivatives:
|
||||
xs[d] += xterm * t
|
||||
if reflect:
|
||||
ys[d] += yterm * t
|
||||
t *= logw
|
||||
else:
|
||||
xs[0] += xterm
|
||||
if reflect:
|
||||
ys[0] += yterm
|
||||
return xs, ys
|
||||
|
||||
@defun
|
||||
def dirichlet(ctx, s, chi=[1], derivative=0):
|
||||
s = ctx.convert(s)
|
||||
q = len(chi)
|
||||
d = int(derivative)
|
||||
if d > 2:
|
||||
raise NotImplementedError("arbitrary order derivatives")
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
if s == 1:
|
||||
have_pole = True
|
||||
for x in chi:
|
||||
if x and x != 1:
|
||||
have_pole = False
|
||||
h = +ctx.eps
|
||||
ctx.prec *= 2*(d+1)
|
||||
s += h
|
||||
if have_pole:
|
||||
return +ctx.inf
|
||||
z = ctx.zero
|
||||
for p in range(1,q+1):
|
||||
if chi[p%q]:
|
||||
if d == 1:
|
||||
z += chi[p%q] * (ctx.zeta(s, (p,q), 1) - \
|
||||
ctx.zeta(s, (p,q))*ctx.log(q))
|
||||
else:
|
||||
z += chi[p%q] * ctx.zeta(s, (p,q))
|
||||
z /= q**s
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return +z
|
||||
|
|
@ -1,840 +0,0 @@
|
|||
"""
|
||||
Implements the PSLQ algorithm for integer relation detection,
|
||||
and derivative algorithms for constant recognition.
|
||||
"""
|
||||
|
||||
from libmp import int_types, sqrt_fixed
|
||||
|
||||
# round to nearest integer (can be done more elegantly...)
|
||||
def round_fixed(x, prec):
|
||||
return ((x + (1<<(prec-1))) >> prec) << prec
|
||||
|
||||
class IdentificationMethods(object):
|
||||
pass
|
||||
|
||||
|
||||
def pslq(ctx, x, tol=None, maxcoeff=1000, maxsteps=100, verbose=False):
|
||||
r"""
|
||||
Given a vector of real numbers `x = [x_0, x_1, ..., x_n]`, ``pslq(x)``
|
||||
uses the PSLQ algorithm to find a list of integers
|
||||
`[c_0, c_1, ..., c_n]` such that
|
||||
|
||||
.. math ::
|
||||
|
||||
|c_1 x_1 + c_2 x_2 + ... + c_n x_n| < \mathrm{tol}
|
||||
|
||||
and such that `\max |c_k| < \mathrm{maxcoeff}`. If no such vector
|
||||
exists, :func:`pslq` returns ``None``. The tolerance defaults to
|
||||
3/4 of the working precision.
|
||||
|
||||
**Examples**
|
||||
|
||||
Find rational approximations for `\pi`::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> pslq([-1, pi], tol=0.01)
|
||||
[22, 7]
|
||||
>>> pslq([-1, pi], tol=0.001)
|
||||
[355, 113]
|
||||
>>> mpf(22)/7; mpf(355)/113; +pi
|
||||
3.14285714285714
|
||||
3.14159292035398
|
||||
3.14159265358979
|
||||
|
||||
Pi is not a rational number with denominator less than 1000::
|
||||
|
||||
>>> pslq([-1, pi])
|
||||
>>>
|
||||
|
||||
To within the standard precision, it can however be approximated
|
||||
by at least one rational number with denominator less than `10^{12}`::
|
||||
|
||||
>>> p, q = pslq([-1, pi], maxcoeff=10**12)
|
||||
>>> print p, q
|
||||
238410049439 75888275702
|
||||
>>> mpf(p)/q
|
||||
3.14159265358979
|
||||
|
||||
The PSLQ algorithm can be applied to long vectors. For example,
|
||||
we can investigate the rational (in)dependence of integer square
|
||||
roots::
|
||||
|
||||
>>> mp.dps = 30
|
||||
>>> pslq([sqrt(n) for n in range(2, 5+1)])
|
||||
>>>
|
||||
>>> pslq([sqrt(n) for n in range(2, 6+1)])
|
||||
>>>
|
||||
>>> pslq([sqrt(n) for n in range(2, 8+1)])
|
||||
[2, 0, 0, 0, 0, 0, -1]
|
||||
|
||||
**Machin formulas**
|
||||
|
||||
A famous formula for `\pi` is Machin's,
|
||||
|
||||
.. math ::
|
||||
|
||||
\frac{\pi}{4} = 4 \operatorname{acot} 5 - \operatorname{acot} 239
|
||||
|
||||
There are actually infinitely many formulas of this type. Two
|
||||
others are
|
||||
|
||||
.. math ::
|
||||
|
||||
\frac{\pi}{4} = \operatorname{acot} 1
|
||||
|
||||
\frac{\pi}{4} = 12 \operatorname{acot} 49 + 32 \operatorname{acot} 57
|
||||
+ 5 \operatorname{acot} 239 + 12 \operatorname{acot} 110443
|
||||
|
||||
We can easily verify the formulas using the PSLQ algorithm::
|
||||
|
||||
>>> mp.dps = 30
|
||||
>>> pslq([pi/4, acot(1)])
|
||||
[1, -1]
|
||||
>>> pslq([pi/4, acot(5), acot(239)])
|
||||
[1, -4, 1]
|
||||
>>> pslq([pi/4, acot(49), acot(57), acot(239), acot(110443)])
|
||||
[1, -12, -32, 5, -12]
|
||||
|
||||
We could try to generate a custom Machin-like formula by running
|
||||
the PSLQ algorithm with a few inverse cotangent values, for example
|
||||
acot(2), acot(3) ... acot(10). Unfortunately, there is a linear
|
||||
dependence among these values, resulting in only that dependence
|
||||
being detected, with a zero coefficient for `\pi`::
|
||||
|
||||
>>> pslq([pi] + [acot(n) for n in range(2,11)])
|
||||
[0, 1, -1, 0, 0, 0, -1, 0, 0, 0]
|
||||
|
||||
We get better luck by removing linearly dependent terms::
|
||||
|
||||
>>> pslq([pi] + [acot(n) for n in range(2,11) if n not in (3, 5)])
|
||||
[1, -8, 0, 0, 4, 0, 0, 0]
|
||||
|
||||
In other words, we found the following formula::
|
||||
|
||||
>>> 8*acot(2) - 4*acot(7)
|
||||
3.14159265358979323846264338328
|
||||
>>> +pi
|
||||
3.14159265358979323846264338328
|
||||
|
||||
**Algorithm**
|
||||
|
||||
This is a fairly direct translation to Python of the pseudocode given by
|
||||
David Bailey, "The PSLQ Integer Relation Algorithm":
|
||||
http://www.cecm.sfu.ca/organics/papers/bailey/paper/html/node3.html
|
||||
|
||||
The present implementation uses fixed-point instead of floating-point
|
||||
arithmetic, since this is significantly (about 7x) faster.
|
||||
"""
|
||||
|
||||
n = len(x)
|
||||
assert n >= 2
|
||||
|
||||
# At too low precision, the algorithm becomes meaningless
|
||||
prec = ctx.prec
|
||||
assert prec >= 53
|
||||
|
||||
if verbose and prec // max(2,n) < 5:
|
||||
print "Warning: precision for PSLQ may be too low"
|
||||
|
||||
target = int(prec * 0.75)
|
||||
|
||||
if tol is None:
|
||||
tol = ctx.mpf(2)**(-target)
|
||||
else:
|
||||
tol = ctx.convert(tol)
|
||||
|
||||
extra = 60
|
||||
prec += extra
|
||||
|
||||
if verbose:
|
||||
print "PSLQ using prec %i and tol %s" % (prec, ctx.nstr(tol))
|
||||
|
||||
tol = ctx.to_fixed(tol, prec)
|
||||
assert tol
|
||||
|
||||
# Convert to fixed-point numbers. The dummy None is added so we can
|
||||
# use 1-based indexing. (This just allows us to be consistent with
|
||||
# Bailey's indexing. The algorithm is 100 lines long, so debugging
|
||||
# a single wrong index can be painful.)
|
||||
x = [None] + [ctx.to_fixed(ctx.mpf(xk), prec) for xk in x]
|
||||
|
||||
# Sanity check on magnitudes
|
||||
minx = min(abs(xx) for xx in x[1:])
|
||||
if not minx:
|
||||
raise ValueError("PSLQ requires a vector of nonzero numbers")
|
||||
if minx < tol//100:
|
||||
if verbose:
|
||||
print "STOPPING: (one number is too small)"
|
||||
return None
|
||||
|
||||
g = sqrt_fixed((4<<prec)//3, prec)
|
||||
A = {}
|
||||
B = {}
|
||||
H = {}
|
||||
# Initialization
|
||||
# step 1
|
||||
for i in xrange(1, n+1):
|
||||
for j in xrange(1, n+1):
|
||||
A[i,j] = B[i,j] = (i==j) << prec
|
||||
H[i,j] = 0
|
||||
# step 2
|
||||
s = [None] + [0] * n
|
||||
for k in xrange(1, n+1):
|
||||
t = 0
|
||||
for j in xrange(k, n+1):
|
||||
t += (x[j]**2 >> prec)
|
||||
s[k] = sqrt_fixed(t, prec)
|
||||
t = s[1]
|
||||
y = x[:]
|
||||
for k in xrange(1, n+1):
|
||||
y[k] = (x[k] << prec) // t
|
||||
s[k] = (s[k] << prec) // t
|
||||
# step 3
|
||||
for i in xrange(1, n+1):
|
||||
for j in xrange(i+1, n):
|
||||
H[i,j] = 0
|
||||
if i <= n-1:
|
||||
if s[i]:
|
||||
H[i,i] = (s[i+1] << prec) // s[i]
|
||||
else:
|
||||
H[i,i] = 0
|
||||
for j in range(1, i):
|
||||
sjj1 = s[j]*s[j+1]
|
||||
if sjj1:
|
||||
H[i,j] = ((-y[i]*y[j])<<prec)//sjj1
|
||||
else:
|
||||
H[i,j] = 0
|
||||
# step 4
|
||||
for i in xrange(2, n+1):
|
||||
for j in xrange(i-1, 0, -1):
|
||||
#t = floor(H[i,j]/H[j,j] + 0.5)
|
||||
if H[j,j]:
|
||||
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
|
||||
else:
|
||||
#t = 0
|
||||
continue
|
||||
y[j] = y[j] + (t*y[i] >> prec)
|
||||
for k in xrange(1, j+1):
|
||||
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
|
||||
for k in xrange(1, n+1):
|
||||
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
|
||||
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
|
||||
# Main algorithm
|
||||
for REP in range(maxsteps):
|
||||
# Step 1
|
||||
m = -1
|
||||
szmax = -1
|
||||
for i in range(1, n):
|
||||
h = H[i,i]
|
||||
sz = (g**i * abs(h)) >> (prec*(i-1))
|
||||
if sz > szmax:
|
||||
m = i
|
||||
szmax = sz
|
||||
# Step 2
|
||||
y[m], y[m+1] = y[m+1], y[m]
|
||||
tmp = {}
|
||||
for i in xrange(1,n+1): H[m,i], H[m+1,i] = H[m+1,i], H[m,i]
|
||||
for i in xrange(1,n+1): A[m,i], A[m+1,i] = A[m+1,i], A[m,i]
|
||||
for i in xrange(1,n+1): B[i,m], B[i,m+1] = B[i,m+1], B[i,m]
|
||||
# Step 3
|
||||
if m <= n - 2:
|
||||
t0 = sqrt_fixed((H[m,m]**2 + H[m,m+1]**2)>>prec, prec)
|
||||
# A zero element probably indicates that the precision has
|
||||
# been exhausted. XXX: this could be spurious, due to
|
||||
# using fixed-point arithmetic
|
||||
if not t0:
|
||||
break
|
||||
t1 = (H[m,m] << prec) // t0
|
||||
t2 = (H[m,m+1] << prec) // t0
|
||||
for i in xrange(m, n+1):
|
||||
t3 = H[i,m]
|
||||
t4 = H[i,m+1]
|
||||
H[i,m] = (t1*t3+t2*t4) >> prec
|
||||
H[i,m+1] = (-t2*t3+t1*t4) >> prec
|
||||
# Step 4
|
||||
for i in xrange(m+1, n+1):
|
||||
for j in xrange(min(i-1, m+1), 0, -1):
|
||||
try:
|
||||
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
|
||||
# Precision probably exhausted
|
||||
except ZeroDivisionError:
|
||||
break
|
||||
y[j] = y[j] + ((t*y[i]) >> prec)
|
||||
for k in xrange(1, j+1):
|
||||
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
|
||||
for k in xrange(1, n+1):
|
||||
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
|
||||
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
|
||||
# Until a relation is found, the error typically decreases
|
||||
# slowly (e.g. a factor 1-10) with each step TODO: we could
|
||||
# compare err from two successive iterations. If there is a
|
||||
# large drop (several orders of magnitude), that indicates a
|
||||
# "high quality" relation was detected. Reporting this to
|
||||
# the user somehow might be useful.
|
||||
best_err = maxcoeff<<prec
|
||||
for i in xrange(1, n+1):
|
||||
err = abs(y[i])
|
||||
# Maybe we are done?
|
||||
if err < tol:
|
||||
# We are done if the coefficients are acceptable
|
||||
vec = [int(round_fixed(B[j,i], prec) >> prec) for j in \
|
||||
range(1,n+1)]
|
||||
if max(abs(v) for v in vec) < maxcoeff:
|
||||
if verbose:
|
||||
print "FOUND relation at iter %i/%i, error: %s" % \
|
||||
(REP, maxsteps, ctx.nstr(err / ctx.mpf(2)**prec, 1))
|
||||
return vec
|
||||
best_err = min(err, best_err)
|
||||
# Calculate a lower bound for the norm. We could do this
|
||||
# more exactly (using the Euclidean norm) but there is probably
|
||||
# no practical benefit.
|
||||
recnorm = max(abs(h) for h in H.values())
|
||||
if recnorm:
|
||||
norm = ((1 << (2*prec)) // recnorm) >> prec
|
||||
norm //= 100
|
||||
else:
|
||||
norm = ctx.inf
|
||||
if verbose:
|
||||
print "%i/%i: Error: %8s Norm: %s" % \
|
||||
(REP, maxsteps, ctx.nstr(best_err / ctx.mpf(2)**prec, 1), norm)
|
||||
if norm >= maxcoeff:
|
||||
break
|
||||
if verbose:
|
||||
print "CANCELLING after step %i/%i." % (REP, maxsteps)
|
||||
print "Could not find an integer relation. Norm bound: %s" % norm
|
||||
return None
|
||||
|
||||
def findpoly(ctx, x, n=1, **kwargs):
|
||||
r"""
|
||||
``findpoly(x, n)`` returns the coefficients of an integer
|
||||
polynomial `P` of degree at most `n` such that `P(x) \approx 0`.
|
||||
If no polynomial having `x` as a root can be found,
|
||||
:func:`findpoly` returns ``None``.
|
||||
|
||||
:func:`findpoly` works by successively calling :func:`pslq` with
|
||||
the vectors `[1, x]`, `[1, x, x^2]`, `[1, x, x^2, x^3]`, ...,
|
||||
`[1, x, x^2, .., x^n]` as input. Keyword arguments given to
|
||||
:func:`findpoly` are forwarded verbatim to :func:`pslq`. In
|
||||
particular, you can specify a tolerance for `P(x)` with ``tol``
|
||||
and a maximum permitted coefficient size with ``maxcoeff``.
|
||||
|
||||
For large values of `n`, it is recommended to run :func:`findpoly`
|
||||
at high precision; preferably 50 digits or more.
|
||||
|
||||
**Examples**
|
||||
|
||||
By default (degree `n = 1`), :func:`findpoly` simply finds a linear
|
||||
polynomial with a rational root::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> findpoly(0.7)
|
||||
[-10, 7]
|
||||
|
||||
The generated coefficient list is valid input to ``polyval`` and
|
||||
``polyroots``::
|
||||
|
||||
>>> nprint(polyval(findpoly(phi, 2), phi), 1)
|
||||
-2.0e-16
|
||||
>>> for r in polyroots(findpoly(phi, 2)):
|
||||
... print r
|
||||
...
|
||||
-0.618033988749895
|
||||
1.61803398874989
|
||||
|
||||
Numbers of the form `m + n \sqrt p` for integers `(m, n, p)` are
|
||||
solutions to quadratic equations. As we find here, `1+\sqrt 2`
|
||||
is a root of the polynomial `x^2 - 2x - 1`::
|
||||
|
||||
>>> findpoly(1+sqrt(2), 2)
|
||||
[1, -2, -1]
|
||||
>>> findroot(lambda x: x**2 - 2*x - 1, 1)
|
||||
2.4142135623731
|
||||
|
||||
Despite only containing square roots, the following number results
|
||||
in a polynomial of degree 4::
|
||||
|
||||
>>> findpoly(sqrt(2)+sqrt(3), 4)
|
||||
[1, 0, -10, 0, 1]
|
||||
|
||||
In fact, `x^4 - 10x^2 + 1` is the *minimal polynomial* of
|
||||
`r = \sqrt 2 + \sqrt 3`, meaning that a rational polynomial of
|
||||
lower degree having `r` as a root does not exist. Given sufficient
|
||||
precision, :func:`findpoly` will usually find the correct
|
||||
minimal polynomial of a given algebraic number.
|
||||
|
||||
**Non-algebraic numbers**
|
||||
|
||||
If :func:`findpoly` fails to find a polynomial with given
|
||||
coefficient size and tolerance constraints, that means no such
|
||||
polynomial exists.
|
||||
|
||||
We can verify that `\pi` is not an algebraic number of degree 3 with
|
||||
coefficients less than 1000::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> findpoly(pi, 3)
|
||||
>>>
|
||||
|
||||
It is always possible to find an algebraic approximation of a number
|
||||
using one (or several) of the following methods:
|
||||
|
||||
1. Increasing the permitted degree
|
||||
2. Allowing larger coefficients
|
||||
3. Reducing the tolerance
|
||||
|
||||
One example of each method is shown below::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> findpoly(pi, 4)
|
||||
[95, -545, 863, -183, -298]
|
||||
>>> findpoly(pi, 3, maxcoeff=10000)
|
||||
[836, -1734, -2658, -457]
|
||||
>>> findpoly(pi, 3, tol=1e-7)
|
||||
[-4, 22, -29, -2]
|
||||
|
||||
It is unknown whether Euler's constant is transcendental (or even
|
||||
irrational). We can use :func:`findpoly` to check that if is
|
||||
an algebraic number, its minimal polynomial must have degree
|
||||
at least 7 and a coefficient of magnitude at least 1000000::
|
||||
|
||||
>>> mp.dps = 200
|
||||
>>> findpoly(euler, 6, maxcoeff=10**6, tol=1e-100, maxsteps=1000)
|
||||
>>>
|
||||
|
||||
Note that the high precision and strict tolerance is necessary
|
||||
for such high-degree runs, since otherwise unwanted low-accuracy
|
||||
approximations will be detected. It may also be necessary to set
|
||||
maxsteps high to prevent a premature exit (before the coefficient
|
||||
bound has been reached). Running with ``verbose=True`` to get an
|
||||
idea what is happening can be useful.
|
||||
"""
|
||||
x = ctx.mpf(x)
|
||||
assert n >= 1
|
||||
if x == 0:
|
||||
return [1, 0]
|
||||
xs = [ctx.mpf(1)]
|
||||
for i in range(1,n+1):
|
||||
xs.append(x**i)
|
||||
a = ctx.pslq(xs, **kwargs)
|
||||
if a is not None:
|
||||
return a[::-1]
|
||||
|
||||
def fracgcd(p, q):
|
||||
x, y = p, q
|
||||
while y:
|
||||
x, y = y, x % y
|
||||
if x != 1:
|
||||
p //= x
|
||||
q //= x
|
||||
if q == 1:
|
||||
return p
|
||||
return p, q
|
||||
|
||||
def pslqstring(r, constants):
|
||||
q = r[0]
|
||||
r = r[1:]
|
||||
s = []
|
||||
for i in range(len(r)):
|
||||
p = r[i]
|
||||
if p:
|
||||
z = fracgcd(-p,q)
|
||||
cs = constants[i][1]
|
||||
if cs == '1':
|
||||
cs = ''
|
||||
else:
|
||||
cs = '*' + cs
|
||||
if isinstance(z, int_types):
|
||||
if z > 0: term = str(z) + cs
|
||||
else: term = ("(%s)" % z) + cs
|
||||
else:
|
||||
term = ("(%s/%s)" % z) + cs
|
||||
s.append(term)
|
||||
s = ' + '.join(s)
|
||||
if '+' in s or '*' in s:
|
||||
s = '(' + s + ')'
|
||||
return s or '0'
|
||||
|
||||
def prodstring(r, constants):
|
||||
q = r[0]
|
||||
r = r[1:]
|
||||
num = []
|
||||
den = []
|
||||
for i in range(len(r)):
|
||||
p = r[i]
|
||||
if p:
|
||||
z = fracgcd(-p,q)
|
||||
cs = constants[i][1]
|
||||
if isinstance(z, int_types):
|
||||
if abs(z) == 1: t = cs
|
||||
else: t = '%s**%s' % (cs, abs(z))
|
||||
([num,den][z<0]).append(t)
|
||||
else:
|
||||
t = '%s**(%s/%s)' % (cs, abs(z[0]), z[1])
|
||||
([num,den][z[0]<0]).append(t)
|
||||
num = '*'.join(num)
|
||||
den = '*'.join(den)
|
||||
if num and den: return "(%s)/(%s)" % (num, den)
|
||||
if num: return num
|
||||
if den: return "1/(%s)" % den
|
||||
|
||||
def quadraticstring(ctx,t,a,b,c):
|
||||
if c < 0:
|
||||
a,b,c = -a,-b,-c
|
||||
u1 = (-b+ctx.sqrt(b**2-4*a*c))/(2*c)
|
||||
u2 = (-b-ctx.sqrt(b**2-4*a*c))/(2*c)
|
||||
if abs(u1-t) < abs(u2-t):
|
||||
if b: s = '((%s+sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
|
||||
else: s = '(sqrt(%s)/%s)' % (-4*a*c,2*c)
|
||||
else:
|
||||
if b: s = '((%s-sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
|
||||
else: s = '(-sqrt(%s)/%s)' % (-4*a*c,2*c)
|
||||
return s
|
||||
|
||||
# Transformation y = f(x,c), with inverse function x = f(y,c)
|
||||
# The third entry indicates whether the transformation is
|
||||
# redundant when c = 1
|
||||
transforms = [
|
||||
(lambda ctx,x,c: x*c, '$y/$c', 0),
|
||||
(lambda ctx,x,c: x/c, '$c*$y', 1),
|
||||
(lambda ctx,x,c: c/x, '$c/$y', 0),
|
||||
(lambda ctx,x,c: (x*c)**2, 'sqrt($y)/$c', 0),
|
||||
(lambda ctx,x,c: (x/c)**2, '$c*sqrt($y)', 1),
|
||||
(lambda ctx,x,c: (c/x)**2, '$c/sqrt($y)', 0),
|
||||
(lambda ctx,x,c: c*x**2, 'sqrt($y)/sqrt($c)', 1),
|
||||
(lambda ctx,x,c: x**2/c, 'sqrt($c)*sqrt($y)', 1),
|
||||
(lambda ctx,x,c: c/x**2, 'sqrt($c)/sqrt($y)', 1),
|
||||
(lambda ctx,x,c: ctx.sqrt(x*c), '$y**2/$c', 0),
|
||||
(lambda ctx,x,c: ctx.sqrt(x/c), '$c*$y**2', 1),
|
||||
(lambda ctx,x,c: ctx.sqrt(c/x), '$c/$y**2', 0),
|
||||
(lambda ctx,x,c: c*ctx.sqrt(x), '$y**2/$c**2', 1),
|
||||
(lambda ctx,x,c: ctx.sqrt(x)/c, '$c**2*$y**2', 1),
|
||||
(lambda ctx,x,c: c/ctx.sqrt(x), '$c**2/$y**2', 1),
|
||||
(lambda ctx,x,c: ctx.exp(x*c), 'log($y)/$c', 0),
|
||||
(lambda ctx,x,c: ctx.exp(x/c), '$c*log($y)', 1),
|
||||
(lambda ctx,x,c: ctx.exp(c/x), '$c/log($y)', 0),
|
||||
(lambda ctx,x,c: c*ctx.exp(x), 'log($y/$c)', 1),
|
||||
(lambda ctx,x,c: ctx.exp(x)/c, 'log($c*$y)', 1),
|
||||
(lambda ctx,x,c: c/ctx.exp(x), 'log($c/$y)', 0),
|
||||
(lambda ctx,x,c: ctx.ln(x*c), 'exp($y)/$c', 0),
|
||||
(lambda ctx,x,c: ctx.ln(x/c), '$c*exp($y)', 1),
|
||||
(lambda ctx,x,c: ctx.ln(c/x), '$c/exp($y)', 0),
|
||||
(lambda ctx,x,c: c*ctx.ln(x), 'exp($y/$c)', 1),
|
||||
(lambda ctx,x,c: ctx.ln(x)/c, 'exp($c*$y)', 1),
|
||||
(lambda ctx,x,c: c/ctx.ln(x), 'exp($c/$y)', 0),
|
||||
]
|
||||
|
||||
def identify(ctx, x, constants=[], tol=None, maxcoeff=1000, full=False,
|
||||
verbose=False):
|
||||
"""
|
||||
Given a real number `x`, ``identify(x)`` attempts to find an exact
|
||||
formula for `x`. This formula is returned as a string. If no match
|
||||
is found, ``None`` is returned. With ``full=True``, a list of
|
||||
matching formulas is returned.
|
||||
|
||||
As a simple example, :func:`identify` will find an algebraic
|
||||
formula for the golden ratio::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> identify(phi)
|
||||
'((1+sqrt(5))/2)'
|
||||
|
||||
:func:`identify` can identify simple algebraic numbers and simple
|
||||
combinations of given base constants, as well as certain basic
|
||||
transformations thereof. More specifically, :func:`identify`
|
||||
looks for the following:
|
||||
|
||||
1. Fractions
|
||||
2. Quadratic algebraic numbers
|
||||
3. Rational linear combinations of the base constants
|
||||
4. Any of the above after first transforming `x` into `f(x)` where
|
||||
`f(x)` is `1/x`, `\sqrt x`, `x^2`, `\log x` or `\exp x`, either
|
||||
directly or with `x` or `f(x)` multiplied or divided by one of
|
||||
the base constants
|
||||
5. Products of fractional powers of the base constants and
|
||||
small integers
|
||||
|
||||
Base constants can be given as a list of strings representing mpmath
|
||||
expressions (:func:`identify` will ``eval`` the strings to numerical
|
||||
values and use the original strings for the output), or as a dict of
|
||||
formula:value pairs.
|
||||
|
||||
In order not to produce spurious results, :func:`identify` should
|
||||
be used with high precision; preferrably 50 digits or more.
|
||||
|
||||
**Examples**
|
||||
|
||||
Simple identifications can be performed safely at standard
|
||||
precision. Here the default recognition of rational, algebraic,
|
||||
and exp/log of algebraic numbers is demonstrated::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> identify(0.22222222222222222)
|
||||
'(2/9)'
|
||||
>>> identify(1.9662210973805663)
|
||||
'sqrt(((24+sqrt(48))/8))'
|
||||
>>> identify(4.1132503787829275)
|
||||
'exp((sqrt(8)/2))'
|
||||
>>> identify(0.881373587019543)
|
||||
'log(((2+sqrt(8))/2))'
|
||||
|
||||
By default, :func:`identify` does not recognize `\pi`. At standard
|
||||
precision it finds a not too useful approximation. At slightly
|
||||
increased precision, this approximation is no longer accurate
|
||||
enough and :func:`identify` more correctly returns ``None``::
|
||||
|
||||
>>> identify(pi)
|
||||
'(2**(176/117)*3**(20/117)*5**(35/39))/(7**(92/117))'
|
||||
>>> mp.dps = 30
|
||||
>>> identify(pi)
|
||||
>>>
|
||||
|
||||
Numbers such as `\pi`, and simple combinations of user-defined
|
||||
constants, can be identified if they are provided explicitly::
|
||||
|
||||
>>> identify(3*pi-2*e, ['pi', 'e'])
|
||||
'(3*pi + (-2)*e)'
|
||||
|
||||
Here is an example using a dict of constants. Note that the
|
||||
constants need not be "atomic"; :func:`identify` can just
|
||||
as well express the given number in terms of expressions
|
||||
given by formulas::
|
||||
|
||||
>>> identify(pi+e, {'a':pi+2, 'b':2*e})
|
||||
'((-2) + 1*a + (1/2)*b)'
|
||||
|
||||
Next, we attempt some identifications with a set of base constants.
|
||||
It is necessary to increase the precision a bit.
|
||||
|
||||
>>> mp.dps = 50
|
||||
>>> base = ['sqrt(2)','pi','log(2)']
|
||||
>>> identify(0.25, base)
|
||||
'(1/4)'
|
||||
>>> identify(3*pi + 2*sqrt(2) + 5*log(2)/7, base)
|
||||
'(2*sqrt(2) + 3*pi + (5/7)*log(2))'
|
||||
>>> identify(exp(pi+2), base)
|
||||
'exp((2 + 1*pi))'
|
||||
>>> identify(1/(3+sqrt(2)), base)
|
||||
'((3/7) + (-1/7)*sqrt(2))'
|
||||
>>> identify(sqrt(2)/(3*pi+4), base)
|
||||
'sqrt(2)/(4 + 3*pi)'
|
||||
>>> identify(5**(mpf(1)/3)*pi*log(2)**2, base)
|
||||
'5**(1/3)*pi*log(2)**2'
|
||||
|
||||
An example of an erroneous solution being found when too low
|
||||
precision is used::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
|
||||
'((11/25) + (-158/75)*pi + (76/75)*e + (44/15)*sqrt(2))'
|
||||
>>> mp.dps = 50
|
||||
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
|
||||
'1/(3*pi + (-4)*e + 2*sqrt(2))'
|
||||
|
||||
**Finding approximate solutions**
|
||||
|
||||
The tolerance ``tol`` defaults to 3/4 of the working precision.
|
||||
Lowering the tolerance is useful for finding approximate matches.
|
||||
We can for example try to generate approximations for pi::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> identify(pi, tol=1e-2)
|
||||
'(22/7)'
|
||||
>>> identify(pi, tol=1e-3)
|
||||
'(355/113)'
|
||||
>>> identify(pi, tol=1e-10)
|
||||
'(5**(339/269))/(2**(64/269)*3**(13/269)*7**(92/269))'
|
||||
|
||||
With ``full=True``, and by supplying a few base constants,
|
||||
``identify`` can generate almost endless lists of approximations
|
||||
for any number (the output below has been truncated to show only
|
||||
the first few)::
|
||||
|
||||
>>> for p in identify(pi, ['e', 'catalan'], tol=1e-5, full=True):
|
||||
... print p
|
||||
... # doctest: +ELLIPSIS
|
||||
e/log((6 + (-4/3)*e))
|
||||
(3**3*5*e*catalan**2)/(2*7**2)
|
||||
sqrt(((-13) + 1*e + 22*catalan))
|
||||
log(((-6) + 24*e + 4*catalan)/e)
|
||||
exp(catalan*((-1/5) + (8/15)*e))
|
||||
catalan*(6 + (-6)*e + 15*catalan)
|
||||
sqrt((5 + 26*e + (-3)*catalan))/e
|
||||
e*sqrt(((-27) + 2*e + 25*catalan))
|
||||
log(((-1) + (-11)*e + 59*catalan))
|
||||
((3/20) + (21/20)*e + (3/20)*catalan)
|
||||
...
|
||||
|
||||
The numerical values are roughly as close to pi as permitted by the
|
||||
specified tolerance:
|
||||
|
||||
>>> e/log(6-4*e/3)
|
||||
3.14157719846001
|
||||
>>> 135*e*catalan**2/98
|
||||
3.14166950419369
|
||||
>>> sqrt(e-13+22*catalan)
|
||||
3.14158000062992
|
||||
>>> log(24*e-6+4*catalan)-1
|
||||
3.14158791577159
|
||||
|
||||
**Symbolic processing**
|
||||
|
||||
The output formula can be evaluated as a Python expression.
|
||||
Note however that if fractions (like '2/3') are present in
|
||||
the formula, Python's :func:`eval()` may erroneously perform
|
||||
integer division. Note also that the output is not necessarily
|
||||
in the algebraically simplest form::
|
||||
|
||||
>>> identify(sqrt(2))
|
||||
'(sqrt(8)/2)'
|
||||
|
||||
As a solution to both problems, consider using SymPy's
|
||||
:func:`sympify` to convert the formula into a symbolic expression.
|
||||
SymPy can be used to pretty-print or further simplify the formula
|
||||
symbolically::
|
||||
|
||||
>>> from sympy import sympify
|
||||
>>> sympify(identify(sqrt(2)))
|
||||
2**(1/2)
|
||||
|
||||
Sometimes :func:`identify` can simplify an expression further than
|
||||
a symbolic algorithm::
|
||||
|
||||
>>> from sympy import simplify
|
||||
>>> x = sympify('-1/(-3/2+(1/2)*5**(1/2))*(3/2-1/2*5**(1/2))**(1/2)')
|
||||
>>> x
|
||||
(3/2 - 5**(1/2)/2)**(-1/2)
|
||||
>>> x = simplify(x)
|
||||
>>> x
|
||||
2/(6 - 2*5**(1/2))**(1/2)
|
||||
>>> mp.dps = 30
|
||||
>>> x = sympify(identify(x.evalf(30)))
|
||||
>>> x
|
||||
1/2 + 5**(1/2)/2
|
||||
|
||||
(In fact, this functionality is available directly in SymPy as the
|
||||
function :func:`nsimplify`, which is essentially a wrapper for
|
||||
:func:`identify`.)
|
||||
|
||||
**Miscellaneous issues and limitations**
|
||||
|
||||
The input `x` must be a real number. All base constants must be
|
||||
positive real numbers and must not be rationals or rational linear
|
||||
combinations of each other.
|
||||
|
||||
The worst-case computation time grows quickly with the number of
|
||||
base constants. Already with 3 or 4 base constants,
|
||||
:func:`identify` may require several seconds to finish. To search
|
||||
for relations among a large number of constants, you should
|
||||
consider using :func:`pslq` directly.
|
||||
|
||||
The extended transformations are applied to x, not the constants
|
||||
separately. As a result, ``identify`` will for example be able to
|
||||
recognize ``exp(2*pi+3)`` with ``pi`` given as a base constant, but
|
||||
not ``2*exp(pi)+3``. It will be able to recognize the latter if
|
||||
``exp(pi)`` is given explicitly as a base constant.
|
||||
|
||||
"""
|
||||
|
||||
solutions = []
|
||||
|
||||
def addsolution(s):
|
||||
if verbose: print "Found: ", s
|
||||
solutions.append(s)
|
||||
|
||||
x = ctx.mpf(x)
|
||||
|
||||
# Further along, x will be assumed positive
|
||||
if x == 0:
|
||||
if full: return ['0']
|
||||
else: return '0'
|
||||
if x < 0:
|
||||
sol = ctx.identify(-x, constants, tol, maxcoeff, full, verbose)
|
||||
if sol is None:
|
||||
return sol
|
||||
if full:
|
||||
return ["-(%s)"%s for s in sol]
|
||||
else:
|
||||
return "-(%s)" % sol
|
||||
|
||||
if tol:
|
||||
tol = ctx.mpf(tol)
|
||||
else:
|
||||
tol = ctx.eps**0.7
|
||||
M = maxcoeff
|
||||
|
||||
if constants:
|
||||
if isinstance(constants, dict):
|
||||
constants = [(ctx.mpf(v), name) for (name, v) in constants.items()]
|
||||
else:
|
||||
namespace = dict((name, getattr(ctx,name)) for name in dir(ctx))
|
||||
constants = [(eval(p, namespace), p) for p in constants]
|
||||
else:
|
||||
constants = []
|
||||
|
||||
# We always want to find at least rational terms
|
||||
if 1 not in [value for (name, value) in constants]:
|
||||
constants = [(ctx.mpf(1), '1')] + constants
|
||||
|
||||
# PSLQ with simple algebraic and functional transformations
|
||||
for ft, ftn, red in transforms:
|
||||
for c, cn in constants:
|
||||
if red and cn == '1':
|
||||
continue
|
||||
t = ft(ctx,x,c)
|
||||
# Prevent exponential transforms from wreaking havoc
|
||||
if abs(t) > M**2 or abs(t) < tol:
|
||||
continue
|
||||
# Linear combination of base constants
|
||||
r = ctx.pslq([t] + [a[0] for a in constants], tol, M)
|
||||
s = None
|
||||
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
|
||||
s = pslqstring(r, constants)
|
||||
# Quadratic algebraic numbers
|
||||
else:
|
||||
q = ctx.pslq([ctx.one, t, t**2], tol, M)
|
||||
if q is not None and len(q) == 3 and q[2]:
|
||||
aa, bb, cc = q
|
||||
if max(abs(aa),abs(bb),abs(cc)) <= M:
|
||||
s = quadraticstring(ctx,t,aa,bb,cc)
|
||||
if s:
|
||||
if cn == '1' and ('/$c' in ftn):
|
||||
s = ftn.replace('$y', s).replace('/$c', '')
|
||||
else:
|
||||
s = ftn.replace('$y', s).replace('$c', cn)
|
||||
addsolution(s)
|
||||
if not full: return solutions[0]
|
||||
|
||||
if verbose:
|
||||
print "."
|
||||
|
||||
# Check for a direct multiplicative formula
|
||||
if x != 1:
|
||||
# Allow fractional powers of fractions
|
||||
ilogs = [2,3,5,7]
|
||||
# Watch out for existing fractional powers of fractions
|
||||
logs = []
|
||||
for a, s in constants:
|
||||
if not sum(bool(ctx.findpoly(ctx.ln(a)/ctx.ln(i),1)) for i in ilogs):
|
||||
logs.append((ctx.ln(a), s))
|
||||
logs = [(ctx.ln(i),str(i)) for i in ilogs] + logs
|
||||
r = ctx.pslq([ctx.ln(x)] + [a[0] for a in logs], tol, M)
|
||||
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
|
||||
addsolution(prodstring(r, logs))
|
||||
if not full: return solutions[0]
|
||||
|
||||
if full:
|
||||
return sorted(solutions, key=len)
|
||||
else:
|
||||
return None
|
||||
|
||||
IdentificationMethods.pslq = pslq
|
||||
IdentificationMethods.findpoly = findpoly
|
||||
IdentificationMethods.identify = identify
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
from libmpf import (prec_to_dps, dps_to_prec, repr_dps,
|
||||
round_down, round_up, round_floor, round_ceiling, round_nearest,
|
||||
to_pickable, from_pickable, ComplexResult,
|
||||
fzero, fnzero, fone, fnone, ftwo, ften, fhalf, fnan, finf, fninf,
|
||||
math_float_inf, round_int, normalize, normalize1,
|
||||
from_man_exp, from_int, to_man_exp, to_int, mpf_ceil, mpf_floor,
|
||||
from_float, to_float, from_rational, to_rational, to_fixed,
|
||||
mpf_rand, mpf_eq, mpf_hash, mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_ge,
|
||||
mpf_pos, mpf_neg, mpf_abs, mpf_sign, mpf_add, mpf_sub, mpf_sum,
|
||||
mpf_mul, mpf_mul_int, mpf_shift, mpf_frexp,
|
||||
mpf_div, mpf_rdiv_int, mpf_mod, mpf_pow_int,
|
||||
mpf_perturb,
|
||||
to_digits_exp, to_str, str_to_man_exp, from_str, from_bstr, to_bstr,
|
||||
mpf_sqrt, mpf_hypot)
|
||||
|
||||
from libmpc import (mpc_one, mpc_zero, mpc_two, mpc_half,
|
||||
mpc_is_inf, mpc_is_infnan, mpc_to_str, mpc_to_complex, mpc_hash,
|
||||
mpc_conjugate, mpc_is_nonzero, mpc_add, mpc_add_mpf,
|
||||
mpc_sub, mpc_sub_mpf, mpc_pos, mpc_neg, mpc_shift, mpc_abs,
|
||||
mpc_arg, mpc_floor, mpc_ceil, mpc_mul, mpc_square,
|
||||
mpc_mul_mpf, mpc_mul_imag_mpf, mpc_mul_int,
|
||||
mpc_div, mpc_div_mpf, mpc_reciprocal, mpc_mpf_div,
|
||||
complex_int_pow, mpc_pow, mpc_pow_mpf, mpc_pow_int,
|
||||
mpc_sqrt, mpc_nthroot, mpc_cbrt, mpc_exp, mpc_log, mpc_cos, mpc_sin,
|
||||
mpc_tan, mpc_cos_pi, mpc_sin_pi, mpc_cosh, mpc_sinh, mpc_tanh,
|
||||
mpc_atan, mpc_acos, mpc_asin, mpc_asinh, mpc_acosh, mpc_atanh,
|
||||
mpc_fibonacci, mpf_expj, mpf_expjpi, mpc_expj, mpc_expjpi)
|
||||
|
||||
from libelefun import (ln2_fixed, mpf_ln2, ln10_fixed, mpf_ln10,
|
||||
pi_fixed, mpf_pi, e_fixed, mpf_e, phi_fixed, mpf_phi,
|
||||
degree_fixed, mpf_degree,
|
||||
mpf_pow, mpf_nthroot, mpf_cbrt, log_int_fixed, agm_fixed,
|
||||
mpf_log, mpf_log_hypot, mpf_exp, mpf_cos_sin, mpf_cos, mpf_sin, mpf_tan,
|
||||
mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi, mpf_cosh_sinh,
|
||||
mpf_cosh, mpf_sinh, mpf_tanh, mpf_atan, mpf_atan2, mpf_asin,
|
||||
mpf_acos, mpf_asinh, mpf_acosh, mpf_atanh, mpf_fibonacci)
|
||||
|
||||
from libhyper import (NoConvergence, make_hyp_summator,
|
||||
mpf_erf, mpf_erfc, mpf_ei, mpc_ei, mpf_e1, mpc_e1, mpf_expint,
|
||||
mpf_ci_si, mpf_ci, mpf_si, mpc_ci, mpc_si, mpf_besseljn,
|
||||
mpc_besseljn, mpf_agm, mpf_agm1, mpc_agm, mpc_agm1,
|
||||
mpf_ellipk, mpc_ellipk, mpf_ellipe, mpc_ellipe)
|
||||
|
||||
from gammazeta import (catalan_fixed, mpf_catalan,
|
||||
khinchin_fixed, mpf_khinchin, glaisher_fixed, mpf_glaisher,
|
||||
apery_fixed, mpf_apery, euler_fixed, mpf_euler, mertens_fixed,
|
||||
mpf_mertens, twinprime_fixed, mpf_twinprime,
|
||||
mpf_bernoulli, bernfrac, mpf_gamma_int,
|
||||
mpf_factorial, mpc_factorial, mpf_gamma, mpc_gamma,
|
||||
mpf_harmonic, mpc_harmonic, mpf_psi0, mpc_psi0,
|
||||
mpf_psi, mpc_psi, mpf_zeta_int, mpf_zeta, mpc_zeta,
|
||||
mpf_altzeta, mpc_altzeta, mpf_zetasum, mpc_zetasum)
|
||||
|
||||
from libmpi import (mpi_str, mpi_add, mpi_sub, mpi_delta, mpi_mid,
|
||||
mpi_pos, mpi_neg, mpi_abs, mpi_mul, mpi_div, mpi_exp,
|
||||
mpi_log, mpi_sqrt, mpi_pow_int, mpi_pow, mpi_cos_sin,
|
||||
mpi_cos, mpi_sin, mpi_tan, mpi_cot)
|
||||
|
||||
from libintmath import (trailing, bitcount, numeral, bin_to_radix,
|
||||
isqrt, isqrt_small, isqrt_fast, sqrt_fixed, sqrtrem, ifib, ifac,
|
||||
list_primes, moebius, gcd, eulernum)
|
||||
|
||||
from backend import (gmpy, sage, BACKEND, STRICT, MPZ, MPZ_TYPE,
|
||||
MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_THREE, MPZ_FIVE, int_types)
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
#----------------------------------------------------------------------------#
|
||||
# Support GMPY for high-speed large integer arithmetic. #
|
||||
# #
|
||||
# To allow an external module to handle arithmetic, we need to make sure #
|
||||
# that all high-precision variables are declared of the correct type. MPZ #
|
||||
# is the constructor for the high-precision type. It defaults to Python's #
|
||||
# long type but can be assinged another type, typically gmpy.mpz. #
|
||||
# #
|
||||
# MPZ must be used for the mantissa component of an mpf and must be used #
|
||||
# for internal fixed-point operations. #
|
||||
# #
|
||||
# Side-effects #
|
||||
# 1) "is" cannot be used to test for special values. Must use "==". #
|
||||
# 2) There are bugs in GMPY prior to v1.02 so we must use v1.03 or later. #
|
||||
#----------------------------------------------------------------------------#
|
||||
|
||||
# So we can import it from this module
|
||||
gmpy = None
|
||||
sage = None
|
||||
sage_utils = None
|
||||
|
||||
BACKEND = 'python'
|
||||
MPZ = long
|
||||
|
||||
if 'MPMATH_NOGMPY' not in os.environ:
|
||||
try:
|
||||
import gmpy
|
||||
if gmpy.version() >= '1.03':
|
||||
BACKEND = 'gmpy'
|
||||
MPZ = gmpy.mpz
|
||||
except:
|
||||
pass
|
||||
|
||||
if 'MPMATH_NOSAGE' not in os.environ:
|
||||
try:
|
||||
import sage.all
|
||||
import sage.libs.mpmath.utils as _sage_utils
|
||||
sage = sage.all
|
||||
sage_utils = _sage_utils
|
||||
BACKEND = 'sage'
|
||||
MPZ = sage.Integer
|
||||
except:
|
||||
pass
|
||||
|
||||
if 'MPMATH_STRICT' in os.environ:
|
||||
STRICT = True
|
||||
else:
|
||||
STRICT = False
|
||||
|
||||
MPZ_TYPE = type(MPZ(0))
|
||||
MPZ_ZERO = MPZ(0)
|
||||
MPZ_ONE = MPZ(1)
|
||||
MPZ_TWO = MPZ(2)
|
||||
MPZ_THREE = MPZ(3)
|
||||
MPZ_FIVE = MPZ(5)
|
||||
|
||||
if BACKEND == 'python':
|
||||
int_types = (int, long)
|
||||
else:
|
||||
int_types = (int, long, MPZ_TYPE)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,461 +0,0 @@
|
|||
"""
|
||||
Utility functions for integer math.
|
||||
|
||||
TODO: rename, cleanup, perhaps move the gmpy wrapper code
|
||||
here from settings.py
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from bisect import bisect
|
||||
|
||||
from backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
|
||||
|
||||
def giant_steps(start, target, n=2):
|
||||
"""
|
||||
Return a list of integers ~=
|
||||
|
||||
[start, n*start, ..., target/n^2, target/n, target]
|
||||
|
||||
but conservatively rounded so that the quotient between two
|
||||
successive elements is actually slightly less than n.
|
||||
|
||||
With n = 2, this describes suitable precision steps for a
|
||||
quadratically convergent algorithm such as Newton's method;
|
||||
with n = 3 steps for cubic convergence (Halley's method), etc.
|
||||
|
||||
>>> giant_steps(50,1000)
|
||||
[66, 128, 253, 502, 1000]
|
||||
>>> giant_steps(50,1000,4)
|
||||
[65, 252, 1000]
|
||||
|
||||
"""
|
||||
L = [target]
|
||||
while L[-1] > start*n:
|
||||
L = L + [L[-1]//n + 2]
|
||||
return L[::-1]
|
||||
|
||||
def rshift(x, n):
|
||||
"""For an integer x, calculate x >> n with the fastest (floor)
|
||||
rounding. Unlike the plain Python expression (x >> n), n is
|
||||
allowed to be negative, in which case a left shift is performed."""
|
||||
if n >= 0: return x >> n
|
||||
else: return x << (-n)
|
||||
|
||||
def lshift(x, n):
|
||||
"""For an integer x, calculate x << n. Unlike the plain Python
|
||||
expression (x << n), n is allowed to be negative, in which case a
|
||||
right shift with default (floor) rounding is performed."""
|
||||
if n >= 0: return x << n
|
||||
else: return x >> (-n)
|
||||
|
||||
if BACKEND == 'sage':
|
||||
import operator
|
||||
rshift = operator.rshift
|
||||
lshift = operator.lshift
|
||||
|
||||
def python_trailing(n):
|
||||
"""Count the number of trailing zero bits in abs(n)."""
|
||||
if not n:
|
||||
return 0
|
||||
t = 0
|
||||
while not n & 1:
|
||||
n >>= 1
|
||||
t += 1
|
||||
return t
|
||||
|
||||
def gmpy_trailing(n):
|
||||
"""Count the number of trailing zero bits in abs(n) using gmpy."""
|
||||
if n: return MPZ(n).scan1()
|
||||
else: return 0
|
||||
|
||||
# Small powers of 2
|
||||
powers = [1<<_ for _ in range(300)]
|
||||
|
||||
def python_bitcount(n):
|
||||
"""Calculate bit size of the nonnegative integer n."""
|
||||
bc = bisect(powers, n)
|
||||
if bc != 300:
|
||||
return bc
|
||||
bc = int(math.log(n, 2)) - 4
|
||||
return bc + bctable[n>>bc]
|
||||
|
||||
def gmpy_bitcount(n):
|
||||
"""Calculate bit size of the nonnegative integer n."""
|
||||
if n: return MPZ(n).numdigits(2)
|
||||
else: return 0
|
||||
|
||||
#def sage_bitcount(n):
|
||||
# if n: return MPZ(n).nbits()
|
||||
# else: return 0
|
||||
|
||||
def sage_trailing(n):
|
||||
return MPZ(n).trailing_zero_bits()
|
||||
|
||||
if BACKEND == 'gmpy':
|
||||
bitcount = gmpy_bitcount
|
||||
trailing = gmpy_trailing
|
||||
elif BACKEND == 'sage':
|
||||
sage_bitcount = sage_utils.bitcount
|
||||
bitcount = sage_bitcount
|
||||
trailing = sage_trailing
|
||||
else:
|
||||
bitcount = python_bitcount
|
||||
trailing = python_trailing
|
||||
|
||||
if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
|
||||
bitcount = gmpy.bit_length
|
||||
|
||||
# Used to avoid slow function calls as far as possible
|
||||
trailtable = map(trailing, range(256))
|
||||
bctable = map(bitcount, range(1024))
|
||||
|
||||
# TODO: speed up for bases 2, 4, 8, 16, ...
|
||||
|
||||
def bin_to_radix(x, xbits, base, bdigits):
|
||||
"""Changes radix of a fixed-point number; i.e., converts
|
||||
x * 2**xbits to floor(x * 10**bdigits)."""
|
||||
return x * (MPZ(base)**bdigits) >> xbits
|
||||
|
||||
stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
def small_numeral(n, base=10, digits=stddigits):
|
||||
"""Return the string numeral of a positive integer in an arbitrary
|
||||
base. Most efficient for small input."""
|
||||
if base == 10:
|
||||
return str(n)
|
||||
digs = []
|
||||
while n:
|
||||
n, digit = divmod(n, base)
|
||||
digs.append(digits[digit])
|
||||
return "".join(digs[::-1])
|
||||
|
||||
def numeral_python(n, base=10, size=0, digits=stddigits):
|
||||
"""Represent the integer n as a string of digits in the given base.
|
||||
Recursive division is used to make this function about 3x faster
|
||||
than Python's str() for converting integers to decimal strings.
|
||||
|
||||
The 'size' parameters specifies the number of digits in n; this
|
||||
number is only used to determine splitting points and need not be
|
||||
exact."""
|
||||
if n <= 0:
|
||||
if not n:
|
||||
return "0"
|
||||
return "-" + numeral(-n, base, size, digits)
|
||||
# Fast enough to do directly
|
||||
if size < 250:
|
||||
return small_numeral(n, base, digits)
|
||||
# Divide in half
|
||||
half = (size // 2) + (size & 1)
|
||||
A, B = divmod(n, base**half)
|
||||
ad = numeral(A, base, half, digits)
|
||||
bd = numeral(B, base, half, digits).rjust(half, "0")
|
||||
return ad + bd
|
||||
|
||||
def numeral_gmpy(n, base=10, size=0, digits=stddigits):
|
||||
"""Represent the integer n as a string of digits in the given base.
|
||||
Recursive division is used to make this function about 3x faster
|
||||
than Python's str() for converting integers to decimal strings.
|
||||
|
||||
The 'size' parameters specifies the number of digits in n; this
|
||||
number is only used to determine splitting points and need not be
|
||||
exact."""
|
||||
if n < 0:
|
||||
return "-" + numeral(-n, base, size, digits)
|
||||
# gmpy.digits() may cause a segmentation fault when trying to convert
|
||||
# extremely large values to a string. The size limit may need to be
|
||||
# adjusted on some platforms, but 1500000 works on Windows and Linux.
|
||||
if size < 1500000:
|
||||
return gmpy.digits(n, base)
|
||||
# Divide in half
|
||||
half = (size // 2) + (size & 1)
|
||||
A, B = divmod(n, MPZ(base)**half)
|
||||
ad = numeral(A, base, half, digits)
|
||||
bd = numeral(B, base, half, digits).rjust(half, "0")
|
||||
return ad + bd
|
||||
|
||||
if BACKEND == "gmpy":
|
||||
numeral = numeral_gmpy
|
||||
else:
|
||||
numeral = numeral_python
|
||||
|
||||
_1_800 = 1<<800
|
||||
_1_600 = 1<<600
|
||||
_1_400 = 1<<400
|
||||
_1_200 = 1<<200
|
||||
_1_100 = 1<<100
|
||||
_1_50 = 1<<50
|
||||
|
||||
def isqrt_small_python(x):
|
||||
"""
|
||||
Correctly (floor) rounded integer square root, using
|
||||
division. Fast up to ~200 digits.
|
||||
"""
|
||||
if not x:
|
||||
return x
|
||||
if x < _1_800:
|
||||
# Exact with IEEE double precision arithmetic
|
||||
if x < _1_50:
|
||||
return int(x**0.5)
|
||||
# Initial estimate can be any integer >= the true root; round up
|
||||
r = int(x**0.5 * 1.00000000000001) + 1
|
||||
else:
|
||||
bc = bitcount(x)
|
||||
n = bc//2
|
||||
r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
|
||||
# The following iteration now precisely computes floor(sqrt(x))
|
||||
# See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
|
||||
# Perspective"
|
||||
while 1:
|
||||
y = (r+x//r)>>1
|
||||
if y >= r:
|
||||
return r
|
||||
r = y
|
||||
|
||||
def isqrt_fast_python(x):
|
||||
"""
|
||||
Fast approximate integer square root, computed using division-free
|
||||
Newton iteration for large x. For random integers the result is almost
|
||||
always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
|
||||
0.1% probability. If x is very close to an exact square, the answer is
|
||||
1 ulp wrong with high probability.
|
||||
|
||||
With 0 guard bits, the largest error over a set of 10^5 random
|
||||
inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
|
||||
almost certainly guarantees a max 1 ulp error.
|
||||
"""
|
||||
# Use direct division-based iteration if sqrt(x) < 2^400
|
||||
# Assume floating-point square root accurate to within 1 ulp, then:
|
||||
# 0 Newton iterations good to 52 bits
|
||||
# 1 Newton iterations good to 104 bits
|
||||
# 2 Newton iterations good to 208 bits
|
||||
# 3 Newton iterations good to 416 bits
|
||||
if x < _1_800:
|
||||
y = int(x**0.5)
|
||||
if x >= _1_100:
|
||||
y = (y + x//y) >> 1
|
||||
if x >= _1_200:
|
||||
y = (y + x//y) >> 1
|
||||
if x >= _1_400:
|
||||
y = (y + x//y) >> 1
|
||||
return y
|
||||
bc = bitcount(x)
|
||||
guard_bits = 10
|
||||
x <<= 2*guard_bits
|
||||
bc += 2*guard_bits
|
||||
bc += (bc&1)
|
||||
hbc = bc//2
|
||||
startprec = min(50, hbc)
|
||||
# Newton iteration for 1/sqrt(x), with floating-point starting value
|
||||
r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
|
||||
pp = startprec
|
||||
for p in giant_steps(startprec, hbc):
|
||||
# r**2, scaled from real size 2**(-bc) to 2**p
|
||||
r2 = (r*r) >> (2*pp - p)
|
||||
# x*r**2, scaled from real size ~1.0 to 2**p
|
||||
xr2 = ((x >> (bc-p)) * r2) >> p
|
||||
# New value of r, scaled from real size 2**(-bc/2) to 2**p
|
||||
r = (r * ((3<<p) - xr2)) >> (pp+1)
|
||||
pp = p
|
||||
# (1/sqrt(x))*x = sqrt(x)
|
||||
return (r*(x>>hbc)) >> (p+guard_bits)
|
||||
|
||||
def sqrtrem_python(x):
|
||||
"""Correctly rounded integer (floor) square root with remainder."""
|
||||
# to check cutoff:
|
||||
# plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
|
||||
if x < _1_600:
|
||||
y = isqrt_small_python(x)
|
||||
return y, x - y*y
|
||||
y = isqrt_fast_python(x) + 1
|
||||
rem = x - y*y
|
||||
# Correct remainder
|
||||
while rem < 0:
|
||||
y -= 1
|
||||
rem += (1+2*y)
|
||||
else:
|
||||
if rem:
|
||||
while rem > 2*(1+y):
|
||||
y += 1
|
||||
rem -= (1+2*y)
|
||||
return y, rem
|
||||
|
||||
def isqrt_python(x):
|
||||
"""Integer square root with correct (floor) rounding."""
|
||||
return sqrtrem_python(x)[0]
|
||||
|
||||
def sqrt_fixed(x, prec):
|
||||
return isqrt_fast(x<<prec)
|
||||
|
||||
sqrt_fixed2 = sqrt_fixed
|
||||
|
||||
if BACKEND == 'gmpy':
|
||||
isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
|
||||
sqrtrem = gmpy.sqrtrem
|
||||
elif BACKEND == 'sage':
|
||||
isqrt_small = isqrt_fast = isqrt = lambda n: MPZ(n).isqrt()
|
||||
sqrtrem = lambda n: MPZ(n).sqrtrem()
|
||||
else:
|
||||
isqrt_small = isqrt_small_python
|
||||
isqrt_fast = isqrt_fast_python
|
||||
isqrt = isqrt_python
|
||||
sqrtrem = sqrtrem_python
|
||||
|
||||
|
||||
def ifib(n, _cache={}):
|
||||
"""Computes the nth Fibonacci number as an integer, for
|
||||
integer n."""
|
||||
if n < 0:
|
||||
return (-1)**(-n+1) * ifib(-n)
|
||||
if n in _cache:
|
||||
return _cache[n]
|
||||
m = n
|
||||
# Use Dijkstra's logarithmic algorithm
|
||||
# The following implementation is basically equivalent to
|
||||
# http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
|
||||
a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
|
||||
while n:
|
||||
if n & 1:
|
||||
aq = a*q
|
||||
a, b = b*q+aq+a*p, b*p+aq
|
||||
n -= 1
|
||||
else:
|
||||
qq = q*q
|
||||
p, q = p*p+qq, qq+2*p*q
|
||||
n >>= 1
|
||||
if m < 250:
|
||||
_cache[m] = b
|
||||
return b
|
||||
|
||||
MAX_FACTORIAL_CACHE = 1000
|
||||
|
||||
def ifac(n, memo={0:1, 1:1}):
|
||||
"""Return n factorial (for integers n >= 0 only)."""
|
||||
f = memo.get(n)
|
||||
if f:
|
||||
return f
|
||||
k = len(memo)
|
||||
p = memo[k-1]
|
||||
MAX = MAX_FACTORIAL_CACHE
|
||||
while k <= n:
|
||||
p *= k
|
||||
if k <= MAX:
|
||||
memo[k] = p
|
||||
k += 1
|
||||
return p
|
||||
|
||||
if BACKEND == 'gmpy':
|
||||
ifac = gmpy.fac
|
||||
elif BACKEND == 'sage':
|
||||
ifac = lambda n: int(sage.factorial(n))
|
||||
ifib = sage.fibonacci
|
||||
|
||||
def list_primes(n):
|
||||
n = n + 1
|
||||
sieve = range(n)
|
||||
sieve[:2] = [0, 0]
|
||||
for i in xrange(2, int(n**0.5)+1):
|
||||
if sieve[i]:
|
||||
for j in xrange(i**2, n, i):
|
||||
sieve[j] = 0
|
||||
return [p for p in sieve if p]
|
||||
|
||||
if BACKEND == 'sage':
|
||||
def list_primes(n):
|
||||
return list(sage.primes(n+1))
|
||||
|
||||
def moebius(n):
|
||||
"""
|
||||
Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
|
||||
is a product of `k` distinct primes and `mu(n) = 0` otherwise.
|
||||
|
||||
TODO: speed up using factorization
|
||||
"""
|
||||
n = abs(int(n))
|
||||
if n < 2:
|
||||
return n
|
||||
factors = []
|
||||
for p in xrange(2, n+1):
|
||||
if not (n % p):
|
||||
if not (n % p**2):
|
||||
return 0
|
||||
if not sum(p % f for f in factors):
|
||||
factors.append(p)
|
||||
return (-1)**len(factors)
|
||||
|
||||
def gcd(*args):
|
||||
a = 0
|
||||
for b in args:
|
||||
if a:
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
else:
|
||||
a = b
|
||||
return a
|
||||
|
||||
|
||||
# Comment by Juan Arias de Reyna:
|
||||
#
|
||||
# I learn this method to compute EulerE[2n] from van de Lune.
|
||||
#
|
||||
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
|
||||
#
|
||||
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies
|
||||
#
|
||||
# a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
|
||||
#
|
||||
# a(n,j) = a(n-1,j) when n+j is even
|
||||
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
|
||||
#
|
||||
#
|
||||
# But we can use only one array unidimensional a(j) since to compute
|
||||
# a(n,j) we only need to know a(n-1,k) where k and j are of different parity
|
||||
# and we have not to conserve the used values.
|
||||
#
|
||||
# We cached up the values of Euler numbers to sufficiently high order.
|
||||
#
|
||||
# Important Observation: If we pretend to use the numbers
|
||||
# EulerE[1], EulerE[2], ... , EulerE[n]
|
||||
# it is convenient to compute first EulerE[n], since the algorithm
|
||||
# computes first all
|
||||
# the previous ones, and keeps them in the CACHE
|
||||
|
||||
MAX_EULER_CACHE = 500
|
||||
|
||||
def eulernum(m, _cache={0:MPZ_ONE}):
|
||||
r"""
|
||||
Computes the Euler numbers `E(n)`, which can be defined as
|
||||
coefficients of the Taylor expansion of `1/cosh x`:
|
||||
|
||||
.. math ::
|
||||
|
||||
\frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n
|
||||
|
||||
Example::
|
||||
|
||||
>>> [int(eulernum(n)) for n in range(11)]
|
||||
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
|
||||
>>> [int(eulernum(n)) for n in range(11)] # test cache
|
||||
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
|
||||
|
||||
"""
|
||||
# for odd m > 1, the Euler numbers are zero
|
||||
if m & 1:
|
||||
return MPZ_ZERO
|
||||
f = _cache.get(m)
|
||||
if f:
|
||||
return f
|
||||
MAX = MAX_EULER_CACHE
|
||||
n = m
|
||||
a = map(MPZ, [0,0,1,0,0,0])
|
||||
for n in range(1, m+1):
|
||||
for j in range(n+1, -1, -2):
|
||||
a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
|
||||
a.append(0)
|
||||
suma = 0
|
||||
for k in range(n+1, -1, -2):
|
||||
suma += a[k+1]
|
||||
if n <= MAX:
|
||||
_cache[n] = ((-1)**(n//2))*(suma // 2**n)
|
||||
if n == m:
|
||||
return ((-1)**(n//2))*suma // 2**n
|
||||
|
|
@ -1,754 +0,0 @@
|
|||
"""
|
||||
Low-level functions for complex arithmetic.
|
||||
"""
|
||||
|
||||
from backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO
|
||||
|
||||
from libmpf import (\
|
||||
round_floor, round_ceiling, round_down, round_up,
|
||||
round_nearest, round_fast, bitcount,
|
||||
bctable, normalize, normalize1, reciprocal_rnd, rshift, lshift, giant_steps,
|
||||
negative_rnd,
|
||||
to_str, to_fixed, from_man_exp, from_float, to_float, from_int, to_int,
|
||||
fzero, fone, ftwo, fhalf, finf, fninf, fnan, fnone,
|
||||
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul,
|
||||
mpf_div, mpf_mul_int, mpf_shift, mpf_sqrt, mpf_hypot,
|
||||
mpf_rdiv_int, mpf_floor, mpf_ceil,
|
||||
mpf_sign,
|
||||
ComplexResult
|
||||
)
|
||||
|
||||
from libelefun import (\
|
||||
mpf_pi, mpf_exp, mpf_log, mpf_cos_sin, mpf_cosh_sinh, mpf_tan, mpf_pow_int,
|
||||
mpf_log_hypot,
|
||||
mpf_cos_sin_pi, mpf_phi,
|
||||
mpf_atan, mpf_atan2, mpf_cosh, mpf_sinh, mpf_tanh,
|
||||
mpf_asin, mpf_acos, mpf_acosh, mpf_nthroot, mpf_fibonacci
|
||||
)
|
||||
|
||||
# An mpc value is a (real, imag) tuple
|
||||
mpc_one = fone, fzero
|
||||
mpc_zero = fzero, fzero
|
||||
mpc_two = ftwo, fzero
|
||||
mpc_half = (fhalf, fzero)
|
||||
|
||||
_infs = (finf, fninf)
|
||||
_infs_nan = (finf, fninf, fnan)
|
||||
|
||||
def mpc_is_inf(z):
|
||||
"""Check if either real or imaginary part is infinite"""
|
||||
re, im = z
|
||||
if re in _infs: return True
|
||||
if im in _infs: return True
|
||||
return False
|
||||
|
||||
def mpc_is_infnan(z):
|
||||
"""Check if either real or imaginary part is infinite or nan"""
|
||||
re, im = z
|
||||
if re in _infs_nan: return True
|
||||
if im in _infs_nan: return True
|
||||
return False
|
||||
|
||||
def mpc_to_str(z, dps, **kwargs):
|
||||
re, im = z
|
||||
rs = to_str(re, dps)
|
||||
if im[0]:
|
||||
return rs + " - " + to_str(mpf_neg(im), dps, **kwargs) + "j"
|
||||
else:
|
||||
return rs + " + " + to_str(im, dps, **kwargs) + "j"
|
||||
|
||||
def mpc_to_complex(z, strict=False):
|
||||
re, im = z
|
||||
return complex(to_float(re, strict), to_float(im, strict))
|
||||
|
||||
def mpc_hash(z):
|
||||
try:
|
||||
return hash(mpc_to_complex(z, strict=True))
|
||||
except OverflowError:
|
||||
return hash(z)
|
||||
|
||||
def mpc_conjugate(z, prec, rnd=round_fast):
|
||||
re, im = z
|
||||
return re, mpf_neg(im, prec, rnd)
|
||||
|
||||
def mpc_is_nonzero(z):
|
||||
return z != mpc_zero
|
||||
|
||||
def mpc_add(z, w, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
c, d = w
|
||||
return mpf_add(a, c, prec, rnd), mpf_add(b, d, prec, rnd)
|
||||
|
||||
def mpc_add_mpf(z, x, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_add(a, x, prec, rnd), b
|
||||
|
||||
def mpc_sub(z, w, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
c, d = w
|
||||
return mpf_sub(a, c, prec, rnd), mpf_sub(b, d, prec, rnd)
|
||||
|
||||
def mpc_sub_mpf(z, p, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_sub(a, p, prec, rnd), b
|
||||
|
||||
def mpc_pos(z, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_pos(a, prec, rnd), mpf_pos(b, prec, rnd)
|
||||
|
||||
def mpc_neg(z, prec=None, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_neg(a, prec, rnd), mpf_neg(b, prec, rnd)
|
||||
|
||||
def mpc_shift(z, n):
|
||||
a, b = z
|
||||
return mpf_shift(a, n), mpf_shift(b, n)
|
||||
|
||||
def mpc_abs(z, prec, rnd=round_fast):
|
||||
"""Absolute value of a complex number, |a+bi|.
|
||||
Returns an mpf value."""
|
||||
a, b = z
|
||||
return mpf_hypot(a, b, prec, rnd)
|
||||
|
||||
def mpc_arg(z, prec, rnd=round_fast):
|
||||
"""Argument of a complex number. Returns an mpf value."""
|
||||
a, b = z
|
||||
return mpf_atan2(b, a, prec, rnd)
|
||||
|
||||
def mpc_floor(z, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_floor(a, prec, rnd), mpf_floor(b, prec, rnd)
|
||||
|
||||
def mpc_ceil(z, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
return mpf_ceil(a, prec, rnd), mpf_ceil(b, prec, rnd)
|
||||
|
||||
def mpc_mul(z, w, prec, rnd=round_fast):
|
||||
"""
|
||||
Complex multiplication.
|
||||
|
||||
Returns the real and imaginary part of (a+bi)*(c+di), rounded to
|
||||
the specified precision. The rounding mode applies to the real and
|
||||
imaginary parts separately.
|
||||
"""
|
||||
a, b = z
|
||||
c, d = w
|
||||
p = mpf_mul(a, c)
|
||||
q = mpf_mul(b, d)
|
||||
r = mpf_mul(a, d)
|
||||
s = mpf_mul(b, c)
|
||||
re = mpf_sub(p, q, prec, rnd)
|
||||
im = mpf_add(r, s, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_square(z, prec, rnd=round_fast):
|
||||
# (a+b*I)**2 == a**2 - b**2 + 2*I*a*b
|
||||
a, b = z
|
||||
p = mpf_mul(a,a)
|
||||
q = mpf_mul(b,b)
|
||||
r = mpf_mul(a,b, prec, rnd)
|
||||
re = mpf_sub(p, q, prec, rnd)
|
||||
im = mpf_shift(r, 1)
|
||||
return re, im
|
||||
|
||||
def mpc_mul_mpf(z, p, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
re = mpf_mul(a, p, prec, rnd)
|
||||
im = mpf_mul(b, p, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_mul_imag_mpf(z, x, prec, rnd=round_fast):
|
||||
"""
|
||||
Multiply the mpc value z by I*x where x is an mpf value.
|
||||
"""
|
||||
a, b = z
|
||||
re = mpf_neg(mpf_mul(b, x, prec, rnd))
|
||||
im = mpf_mul(a, x, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_mul_int(z, n, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
re = mpf_mul_int(a, n, prec, rnd)
|
||||
im = mpf_mul_int(b, n, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_div(z, w, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
c, d = w
|
||||
wp = prec + 10
|
||||
# mag = c*c + d*d
|
||||
mag = mpf_add(mpf_mul(c, c), mpf_mul(d, d), wp)
|
||||
# (a*c+b*d)/mag, (b*c-a*d)/mag
|
||||
t = mpf_add(mpf_mul(a,c), mpf_mul(b,d), wp)
|
||||
u = mpf_sub(mpf_mul(b,c), mpf_mul(a,d), wp)
|
||||
return mpf_div(t,mag,prec,rnd), mpf_div(u,mag,prec,rnd)
|
||||
|
||||
def mpc_div_mpf(z, p, prec, rnd=round_fast):
|
||||
"""Calculate z/p where p is real"""
|
||||
a, b = z
|
||||
re = mpf_div(a, p, prec, rnd)
|
||||
im = mpf_div(b, p, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_reciprocal(z, prec, rnd=round_fast):
|
||||
"""Calculate 1/z efficiently"""
|
||||
a, b = z
|
||||
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b),prec+10)
|
||||
re = mpf_div(a, m, prec, rnd)
|
||||
im = mpf_neg(mpf_div(b, m, prec, rnd))
|
||||
return re, im
|
||||
|
||||
def mpc_mpf_div(p, z, prec, rnd=round_fast):
|
||||
"""Calculate p/z where p is real efficiently"""
|
||||
a, b = z
|
||||
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b), prec+10)
|
||||
re = mpf_div(mpf_mul(a,p), m, prec, rnd)
|
||||
im = mpf_div(mpf_neg(mpf_mul(b,p)), m, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def complex_int_pow(a, b, n):
|
||||
"""Complex integer power: computes (a+b*I)**n exactly for
|
||||
nonnegative n (a and b must be Python ints)."""
|
||||
wre = 1
|
||||
wim = 0
|
||||
while n:
|
||||
if n & 1:
|
||||
wre, wim = wre*a - wim*b, wim*a + wre*b
|
||||
n -= 1
|
||||
a, b = a*a - b*b, 2*a*b
|
||||
n //= 2
|
||||
return wre, wim
|
||||
|
||||
def mpc_pow(z, w, prec, rnd=round_fast):
|
||||
if w[1] == fzero:
|
||||
return mpc_pow_mpf(z, w[0], prec, rnd)
|
||||
return mpc_exp(mpc_mul(mpc_log(z, prec+10), w, prec+10), prec, rnd)
|
||||
|
||||
def mpc_pow_mpf(z, p, prec, rnd=round_fast):
|
||||
psign, pman, pexp, pbc = p
|
||||
if pexp >= 0:
|
||||
return mpc_pow_int(z, (-1)**psign * (pman<<pexp), prec, rnd)
|
||||
if pexp == -1:
|
||||
sqrtz = mpc_sqrt(z, prec+10)
|
||||
return mpc_pow_int(sqrtz, (-1)**psign * pman, prec, rnd)
|
||||
return mpc_exp(mpc_mul_mpf(mpc_log(z, prec+10), p, prec+10), prec, rnd)
|
||||
|
||||
def mpc_pow_int(z, n, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
if b == fzero:
|
||||
return mpf_pow_int(a, n, prec, rnd), fzero
|
||||
if a == fzero:
|
||||
v = mpf_pow_int(b, n, prec, rnd)
|
||||
n %= 4
|
||||
if n == 0:
|
||||
return v, fzero
|
||||
elif n == 1:
|
||||
return fzero, v
|
||||
elif n == 2:
|
||||
return mpf_neg(v), fzero
|
||||
elif n == 3:
|
||||
return fzero, mpf_neg(v)
|
||||
if n == 0: return mpc_one
|
||||
if n == 1: return mpc_pos(z, prec, rnd)
|
||||
if n == 2: return mpc_square(z, prec, rnd)
|
||||
if n == -1: return mpc_reciprocal(z, prec, rnd)
|
||||
if n < 0: return mpc_reciprocal(mpc_pow_int(z, -n, prec+4), prec, rnd)
|
||||
asign, aman, aexp, abc = a
|
||||
bsign, bman, bexp, bbc = b
|
||||
if asign: aman = -aman
|
||||
if bsign: bman = -bman
|
||||
de = aexp - bexp
|
||||
abs_de = abs(de)
|
||||
exact_size = n*(abs_de + max(abc, bbc))
|
||||
if exact_size < 10000:
|
||||
if de > 0:
|
||||
aman <<= de
|
||||
aexp = bexp
|
||||
else:
|
||||
bman <<= (-de)
|
||||
bexp = aexp
|
||||
re, im = complex_int_pow(aman, bman, n)
|
||||
re = from_man_exp(re, int(n*aexp), prec, rnd)
|
||||
im = from_man_exp(im, int(n*bexp), prec, rnd)
|
||||
return re, im
|
||||
return mpc_exp(mpc_mul_int(mpc_log(z, prec+10), n, prec+10), prec, rnd)
|
||||
|
||||
def mpc_sqrt(z, prec, rnd=round_fast):
|
||||
"""Complex square root (principal branch).
|
||||
|
||||
We have sqrt(a+bi) = sqrt((r+a)/2) + b/sqrt(2*(r+a))*i where
|
||||
r = abs(a+bi), when a+bi is not a negative real number."""
|
||||
a, b = z
|
||||
if b == fzero:
|
||||
if a == fzero:
|
||||
return (a, b)
|
||||
# When a+bi is a negative real number, we get a real sqrt times i
|
||||
if a[0]:
|
||||
im = mpf_sqrt(mpf_neg(a), prec, rnd)
|
||||
return (fzero, im)
|
||||
else:
|
||||
re = mpf_sqrt(a, prec, rnd)
|
||||
return (re, fzero)
|
||||
wp = prec+20
|
||||
if not a[0]: # case a positive
|
||||
t = mpf_add(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) + a
|
||||
u = mpf_shift(t, -1) # u = t/2
|
||||
re = mpf_sqrt(u, prec, rnd) # re = sqrt(u)
|
||||
v = mpf_shift(t, 1) # v = 2*t
|
||||
w = mpf_sqrt(v, wp) # w = sqrt(v)
|
||||
im = mpf_div(b, w, prec, rnd) # im = b / w
|
||||
else: # case a negative
|
||||
t = mpf_sub(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) - a
|
||||
u = mpf_shift(t, -1) # u = t/2
|
||||
im = mpf_sqrt(u, prec, rnd) # im = sqrt(u)
|
||||
v = mpf_shift(t, 1) # v = 2*t
|
||||
w = mpf_sqrt(v, wp) # w = sqrt(v)
|
||||
re = mpf_div(b, w, prec, rnd) # re = b/w
|
||||
if b[0]:
|
||||
re = mpf_neg(re)
|
||||
im = mpf_neg(im)
|
||||
return re, im
|
||||
|
||||
def mpc_nthroot_fixed(a, b, n, prec):
|
||||
# a, b signed integers at fixed precision prec
|
||||
start = 50
|
||||
a1 = int(rshift(a, prec - n*start))
|
||||
b1 = int(rshift(b, prec - n*start))
|
||||
try:
|
||||
r = (a1 + 1j * b1)**(1.0/n)
|
||||
re = r.real
|
||||
im = r.imag
|
||||
re = MPZ(int(re))
|
||||
im = MPZ(int(im))
|
||||
except OverflowError:
|
||||
a1 = from_int(a1, start)
|
||||
b1 = from_int(b1, start)
|
||||
fn = from_int(n)
|
||||
nth = mpf_rdiv_int(1, fn, start)
|
||||
re, im = mpc_pow((a1, b1), (nth, fzero), start)
|
||||
re = to_int(re)
|
||||
im = to_int(im)
|
||||
extra = 10
|
||||
prevp = start
|
||||
extra1 = n
|
||||
for p in giant_steps(start, prec+extra):
|
||||
# this is slow for large n, unlike int_pow_fixed
|
||||
re2, im2 = complex_int_pow(re, im, n-1)
|
||||
re2 = rshift(re2, (n-1)*prevp - p - extra1)
|
||||
im2 = rshift(im2, (n-1)*prevp - p - extra1)
|
||||
r4 = (re2*re2 + im2*im2) >> (p + extra1)
|
||||
ap = rshift(a, prec - p)
|
||||
bp = rshift(b, prec - p)
|
||||
rec = (ap * re2 + bp * im2) >> p
|
||||
imc = (-ap * im2 + bp * re2) >> p
|
||||
reb = (rec << p) // r4
|
||||
imb = (imc << p) // r4
|
||||
re = (reb + (n-1)*lshift(re, p-prevp))//n
|
||||
im = (imb + (n-1)*lshift(im, p-prevp))//n
|
||||
prevp = p
|
||||
return re, im
|
||||
|
||||
def mpc_nthroot(z, n, prec, rnd=round_fast):
|
||||
"""
|
||||
Complex n-th root.
|
||||
|
||||
Use Newton method as in the real case when it is faster,
|
||||
otherwise use z**(1/n)
|
||||
"""
|
||||
a, b = z
|
||||
if a[0] == 0 and b == fzero:
|
||||
re = mpf_nthroot(a, n, prec, rnd)
|
||||
return (re, fzero)
|
||||
if n < 2:
|
||||
if n == 0:
|
||||
return mpc_one
|
||||
if n == 1:
|
||||
return mpc_pos((a, b), prec, rnd)
|
||||
if n == -1:
|
||||
return mpc_div(mpc_one, (a, b), prec, rnd)
|
||||
inverse = mpc_nthroot((a, b), -n, prec+5, reciprocal_rnd[rnd])
|
||||
return mpc_div(mpc_one, inverse, prec, rnd)
|
||||
if n <= 20:
|
||||
prec2 = int(1.2 * (prec + 10))
|
||||
asign, aman, aexp, abc = a
|
||||
bsign, bman, bexp, bbc = b
|
||||
pf = mpc_abs((a,b), prec)
|
||||
if pf[-2] + pf[-1] > -10 and pf[-2] + pf[-1] < prec:
|
||||
af = to_fixed(a, prec2)
|
||||
bf = to_fixed(b, prec2)
|
||||
re, im = mpc_nthroot_fixed(af, bf, n, prec2)
|
||||
extra = 10
|
||||
re = from_man_exp(re, -prec2-extra, prec2, rnd)
|
||||
im = from_man_exp(im, -prec2-extra, prec2, rnd)
|
||||
return re, im
|
||||
fn = from_int(n)
|
||||
prec2 = prec+10 + 10
|
||||
nth = mpf_rdiv_int(1, fn, prec2)
|
||||
re, im = mpc_pow((a, b), (nth, fzero), prec2, rnd)
|
||||
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
|
||||
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_cbrt((a, b), prec, rnd=round_fast):
|
||||
"""
|
||||
Complex cubic root.
|
||||
"""
|
||||
return mpc_nthroot((a, b), 3, prec, rnd)
|
||||
|
||||
def mpc_exp((a, b), prec, rnd=round_fast):
|
||||
"""
|
||||
Complex exponential function.
|
||||
|
||||
We use the direct formula exp(a+bi) = exp(a) * (cos(b) + sin(b)*i)
|
||||
for the computation. This formula is very nice because it is
|
||||
pefectly stable; since we just do real multiplications, the only
|
||||
numerical errors that can creep in are single-ulp rounding errors.
|
||||
|
||||
The formula is efficient since mpmath's real exp is quite fast and
|
||||
since we can compute cos and sin simultaneously.
|
||||
|
||||
It is no problem if a and b are large; if the implementations of
|
||||
exp/cos/sin are accurate and efficient for all real numbers, then
|
||||
so is this function for all complex numbers.
|
||||
"""
|
||||
if a == fzero:
|
||||
return mpf_cos_sin(b, prec, rnd)
|
||||
mag = mpf_exp(a, prec+4, rnd)
|
||||
c, s = mpf_cos_sin(b, prec+4, rnd)
|
||||
re = mpf_mul(mag, c, prec, rnd)
|
||||
im = mpf_mul(mag, s, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_log(z, prec, rnd=round_fast):
|
||||
re = mpf_log_hypot(z[0], z[1], prec, rnd)
|
||||
im = mpc_arg(z, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_cos((a, b), prec, rnd=round_fast):
|
||||
"""Complex cosine. The formula used is cos(a+bi) = cos(a)*cosh(b) -
|
||||
sin(a)*sinh(b)*i.
|
||||
|
||||
The same comments apply as for the complex exp: only real
|
||||
multiplications are pewrormed, so no cancellation errors are
|
||||
possible. The formula is also efficient since we can compute both
|
||||
pairs (cos, sin) and (cosh, sinh) in single stwps."""
|
||||
if a == fzero:
|
||||
return mpf_cosh(b, prec, rnd), fzero
|
||||
wp = prec + 6
|
||||
c, s = mpf_cos_sin(a, wp)
|
||||
ch, sh = mpf_cosh_sinh(b, wp)
|
||||
re = mpf_mul(c, ch, prec, rnd)
|
||||
im = mpf_mul(s, sh, prec, rnd)
|
||||
return re, mpf_neg(im)
|
||||
|
||||
def mpc_sin((a, b), prec, rnd=round_fast):
|
||||
"""Complex sine. We have sin(a+bi) = sin(a)*cosh(b) +
|
||||
cos(a)*sinh(b)*i. See the docstring for mpc_cos for additional
|
||||
comments."""
|
||||
if a == fzero:
|
||||
return fzero, mpf_sinh(b, prec, rnd)
|
||||
wp = prec + 6
|
||||
c, s = mpf_cos_sin(a, wp)
|
||||
ch, sh = mpf_cosh_sinh(b, wp)
|
||||
re = mpf_mul(s, ch, prec, rnd)
|
||||
im = mpf_mul(c, sh, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_tan(z, prec, rnd=round_fast):
|
||||
"""Complex tangent. Computed as tan(a+bi) = sin(2a)/M + sinh(2b)/M*i
|
||||
where M = cos(2a) + cosh(2b)."""
|
||||
a, b = z
|
||||
asign, aman, aexp, abc = a
|
||||
bsign, bman, bexp, bbc = b
|
||||
if b == fzero: return mpf_tan(a, prec, rnd), fzero
|
||||
if a == fzero: return fzero, mpf_tanh(b, prec, rnd)
|
||||
wp = prec + 15
|
||||
a = mpf_shift(a, 1)
|
||||
b = mpf_shift(b, 1)
|
||||
c, s = mpf_cos_sin(a, wp)
|
||||
ch, sh = mpf_cosh_sinh(b, wp)
|
||||
# TODO: handle cancellation when c ~= -1 and ch ~= 1
|
||||
mag = mpf_add(c, ch, wp)
|
||||
re = mpf_div(s, mag, prec, rnd)
|
||||
im = mpf_div(sh, mag, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_cos_pi((a, b), prec, rnd=round_fast):
|
||||
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
|
||||
if a == fzero:
|
||||
return mpf_cosh(b, prec, rnd), fzero
|
||||
wp = prec + 6
|
||||
c, s = mpf_cos_sin_pi(a, wp)
|
||||
ch, sh = mpf_cosh_sinh(b, wp)
|
||||
re = mpf_mul(c, ch, prec, rnd)
|
||||
im = mpf_mul(s, sh, prec, rnd)
|
||||
return re, mpf_neg(im)
|
||||
|
||||
def mpc_sin_pi((a, b), prec, rnd=round_fast):
|
||||
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
|
||||
if a == fzero:
|
||||
return fzero, mpf_sinh(b, prec, rnd)
|
||||
wp = prec + 6
|
||||
c, s = mpf_cos_sin_pi(a, wp)
|
||||
ch, sh = mpf_cosh_sinh(b, wp)
|
||||
re = mpf_mul(s, ch, prec, rnd)
|
||||
im = mpf_mul(c, sh, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_cosh((a, b), prec, rnd=round_fast):
|
||||
"""Complex hyperbolic cosine. Computed as cosh(z) = cos(z*i)."""
|
||||
return mpc_cos((b, mpf_neg(a)), prec, rnd)
|
||||
|
||||
def mpc_sinh((a, b), prec, rnd=round_fast):
|
||||
"""Complex hyperbolic sine. Computed as sinh(z) = -i*sin(z*i)."""
|
||||
b, a = mpc_sin((b, a), prec, rnd)
|
||||
return a, b
|
||||
|
||||
def mpc_tanh((a, b), prec, rnd=round_fast):
|
||||
"""Complex hyperbolic tangent. Computed as tanh(z) = -i*tan(z*i)."""
|
||||
b, a = mpc_tan((b, a), prec, rnd)
|
||||
return a, b
|
||||
|
||||
# TODO: avoid loss of accuracy
|
||||
def mpc_atan(z, prec, rnd=round_fast):
|
||||
a, b = z
|
||||
# atan(z) = (I/2)*(log(1-I*z) - log(1+I*z))
|
||||
# x = 1-I*z = 1 + b - I*a
|
||||
# y = 1+I*z = 1 - b + I*a
|
||||
wp = prec + 15
|
||||
x = mpf_add(fone, b, wp), mpf_neg(a)
|
||||
y = mpf_sub(fone, b, wp), a
|
||||
l1 = mpc_log(x, wp)
|
||||
l2 = mpc_log(y, wp)
|
||||
a, b = mpc_sub(l1, l2, prec, rnd)
|
||||
# (I/2) * (a+b*I) = (-b/2 + a/2*I)
|
||||
v = mpf_neg(mpf_shift(b,-1)), mpf_shift(a,-1)
|
||||
# Subtraction at infinity gives correct real part but
|
||||
# wrong imaginary part (should be zero)
|
||||
if v[1] == fnan and mpc_is_inf(z):
|
||||
v = (v[0], fzero)
|
||||
return v
|
||||
|
||||
beta_crossover = from_float(0.6417)
|
||||
alpha_crossover = from_float(1.5)
|
||||
|
||||
def acos_asin(z, prec, rnd, n):
|
||||
""" complex acos for n = 0, asin for n = 1
|
||||
The algorithm is described in
|
||||
T.E. Hull, T.F. Fairgrieve and P.T.P. Tang
|
||||
'Implementing the Complex Arcsine and Arcosine Functions
|
||||
using Exception Handling',
|
||||
ACM Trans. on Math. Software Vol. 23 (1997), p299
|
||||
The complex acos and asin can be defined as
|
||||
acos(z) = acos(beta) - I*sign(a)* log(alpha + sqrt(alpha**2 -1))
|
||||
asin(z) = asin(beta) + I*sign(a)* log(alpha + sqrt(alpha**2 -1))
|
||||
where z = a + I*b
|
||||
alpha = (1/2)*(r + s); beta = (1/2)*(r - s) = a/alpha
|
||||
r = sqrt((a+1)**2 + y**2); s = sqrt((a-1)**2 + y**2)
|
||||
These expressions are rewritten in different ways in different
|
||||
regions, delimited by two crossovers alpha_crossover and beta_crossover,
|
||||
and by abs(a) <= 1, in order to improve the numerical accuracy.
|
||||
"""
|
||||
a, b = z
|
||||
wp = prec + 10
|
||||
# special cases with real argument
|
||||
if b == fzero:
|
||||
am = mpf_sub(fone, mpf_abs(a), wp)
|
||||
# case abs(a) <= 1
|
||||
if not am[0]:
|
||||
if n == 0:
|
||||
return mpf_acos(a, prec, rnd), fzero
|
||||
else:
|
||||
return mpf_asin(a, prec, rnd), fzero
|
||||
# cases abs(a) > 1
|
||||
else:
|
||||
# case a < -1
|
||||
if a[0]:
|
||||
pi = mpf_pi(prec, rnd)
|
||||
c = mpf_acosh(mpf_neg(a), prec, rnd)
|
||||
if n == 0:
|
||||
return pi, mpf_neg(c)
|
||||
else:
|
||||
return mpf_neg(mpf_shift(pi, -1)), c
|
||||
# case a > 1
|
||||
else:
|
||||
c = mpf_acosh(a, prec, rnd)
|
||||
if n == 0:
|
||||
return fzero, c
|
||||
else:
|
||||
pi = mpf_pi(prec, rnd)
|
||||
return mpf_shift(pi, -1), mpf_neg(c)
|
||||
asign = bsign = 0
|
||||
if a[0]:
|
||||
a = mpf_neg(a)
|
||||
asign = 1
|
||||
if b[0]:
|
||||
b = mpf_neg(b)
|
||||
bsign = 1
|
||||
am = mpf_sub(fone, a, wp)
|
||||
ap = mpf_add(fone, a, wp)
|
||||
r = mpf_hypot(ap, b, wp)
|
||||
s = mpf_hypot(am, b, wp)
|
||||
alpha = mpf_shift(mpf_add(r, s, wp), -1)
|
||||
beta = mpf_div(a, alpha, wp)
|
||||
b2 = mpf_mul(b,b, wp)
|
||||
# case beta <= beta_crossover
|
||||
if not mpf_sub(beta_crossover, beta, wp)[0]:
|
||||
if n == 0:
|
||||
re = mpf_acos(beta, wp)
|
||||
else:
|
||||
re = mpf_asin(beta, wp)
|
||||
else:
|
||||
# to compute the real part in this region use the identity
|
||||
# asin(beta) = atan(beta/sqrt(1-beta**2))
|
||||
# beta/sqrt(1-beta**2) = (alpha + a) * (alpha - a)
|
||||
# alpha + a is numerically accurate; alpha - a can have
|
||||
# cancellations leading to numerical inaccuracies, so rewrite
|
||||
# it in differente ways according to the region
|
||||
Ax = mpf_add(alpha, a, wp)
|
||||
# case a <= 1
|
||||
if not am[0]:
|
||||
# c = b*b/(r + (a+1)); d = (s + (1-a))
|
||||
# alpha - a = (1/2)*(c + d)
|
||||
# case n=0: re = atan(sqrt((1/2) * Ax * (c + d))/a)
|
||||
# case n=1: re = atan(a/sqrt((1/2) * Ax * (c + d)))
|
||||
c = mpf_div(b2, mpf_add(r, ap, wp), wp)
|
||||
d = mpf_add(s, am, wp)
|
||||
re = mpf_shift(mpf_mul(Ax, mpf_add(c, d, wp), wp), -1)
|
||||
if n == 0:
|
||||
re = mpf_atan(mpf_div(mpf_sqrt(re, wp), a, wp), wp)
|
||||
else:
|
||||
re = mpf_atan(mpf_div(a, mpf_sqrt(re, wp), wp), wp)
|
||||
else:
|
||||
# c = Ax/(r + (a+1)); d = Ax/(s - (1-a))
|
||||
# alpha - a = (1/2)*(c + d)
|
||||
# case n = 0: re = atan(b*sqrt(c + d)/2/a)
|
||||
# case n = 1: re = atan(a/(b*sqrt(c + d)/2)
|
||||
c = mpf_div(Ax, mpf_add(r, ap, wp), wp)
|
||||
d = mpf_div(Ax, mpf_sub(s, am, wp), wp)
|
||||
re = mpf_shift(mpf_add(c, d, wp), -1)
|
||||
re = mpf_mul(b, mpf_sqrt(re, wp), wp)
|
||||
if n == 0:
|
||||
re = mpf_atan(mpf_div(re, a, wp), wp)
|
||||
else:
|
||||
re = mpf_atan(mpf_div(a, re, wp), wp)
|
||||
# to compute alpha + sqrt(alpha**2 - 1), if alpha <= alpha_crossover
|
||||
# replace it with 1 + Am1 + sqrt(Am1*(alpha+1)))
|
||||
# where Am1 = alpha -1
|
||||
# if alpha <= alpha_crossover:
|
||||
if not mpf_sub(alpha_crossover, alpha, wp)[0]:
|
||||
c1 = mpf_div(b2, mpf_add(r, ap, wp), wp)
|
||||
# case a < 1
|
||||
if mpf_neg(am)[0]:
|
||||
# Am1 = (1/2) * (b*b/(r + (a+1)) + b*b/(s + (1-a))
|
||||
c2 = mpf_add(s, am, wp)
|
||||
c2 = mpf_div(b2, c2, wp)
|
||||
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
|
||||
else:
|
||||
# Am1 = (1/2) * (b*b/(r + (a+1)) + (s - (1-a)))
|
||||
c2 = mpf_sub(s, am, wp)
|
||||
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
|
||||
# im = log(1 + Am1 + sqrt(Am1*(alpha+1)))
|
||||
im = mpf_mul(Am1, mpf_add(alpha, fone, wp), wp)
|
||||
im = mpf_log(mpf_add(fone, mpf_add(Am1, mpf_sqrt(im, wp), wp), wp), wp)
|
||||
else:
|
||||
# im = log(alpha + sqrt(alpha*alpha - 1))
|
||||
im = mpf_sqrt(mpf_sub(mpf_mul(alpha, alpha, wp), fone, wp), wp)
|
||||
im = mpf_log(mpf_add(alpha, im, wp), wp)
|
||||
if asign:
|
||||
if n == 0:
|
||||
re = mpf_sub(mpf_pi(wp), re, wp)
|
||||
else:
|
||||
re = mpf_neg(re)
|
||||
if not bsign and n == 0:
|
||||
im = mpf_neg(im)
|
||||
if bsign and n == 1:
|
||||
im = mpf_neg(im)
|
||||
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
|
||||
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpc_acos(z, prec, rnd=round_fast):
|
||||
return acos_asin(z, prec, rnd, 0)
|
||||
|
||||
def mpc_asin(z, prec, rnd=round_fast):
|
||||
return acos_asin(z, prec, rnd, 1)
|
||||
|
||||
def mpc_asinh(z, prec, rnd=round_fast):
|
||||
# asinh(z) = I * asin(-I z)
|
||||
a, b = z
|
||||
a, b = mpc_asin((b, mpf_neg(a)), prec, rnd)
|
||||
return mpf_neg(b), a
|
||||
|
||||
def mpc_acosh(z, prec, rnd=round_fast):
|
||||
# acosh(z) = -I * acos(z) for Im(acos(z)) <= 0
|
||||
# +I * acos(z) otherwise
|
||||
a, b = mpc_acos(z, prec, rnd)
|
||||
if b[0] or b == fzero:
|
||||
return mpf_neg(b), a
|
||||
else:
|
||||
return b, mpf_neg(a)
|
||||
|
||||
def mpc_atanh(z, prec, rnd=round_fast):
|
||||
# atanh(z) = (log(1+z)-log(1-z))/2
|
||||
wp = prec + 15
|
||||
a = mpc_add(z, mpc_one, wp)
|
||||
b = mpc_sub(mpc_one, z, wp)
|
||||
a = mpc_log(a, wp)
|
||||
b = mpc_log(b, wp)
|
||||
v = mpc_shift(mpc_sub(a, b, wp), -1)
|
||||
# Subtraction at infinity gives correct imaginary part but
|
||||
# wrong real part (should be zero)
|
||||
if v[0] == fnan and mpc_is_inf(z):
|
||||
v = (fzero, v[1])
|
||||
return v
|
||||
|
||||
def mpc_fibonacci(z, prec, rnd=round_fast):
|
||||
re, im = z
|
||||
if im == fzero:
|
||||
return (mpf_fibonacci(re, prec, rnd), fzero)
|
||||
size = max(abs(re[2]+re[3]), abs(re[2]+re[3]))
|
||||
wp = prec + size + 20
|
||||
a = mpf_phi(wp)
|
||||
b = mpf_add(mpf_shift(a, 1), fnone, wp)
|
||||
u = mpc_pow((a, fzero), z, wp)
|
||||
v = mpc_cos_pi(z, wp)
|
||||
v = mpc_div(v, u, wp)
|
||||
u = mpc_sub(u, v, wp)
|
||||
u = mpc_div_mpf(u, b, prec, rnd)
|
||||
return u
|
||||
|
||||
def mpf_expj(x, prec, rnd='f'):
|
||||
raise ComplexResult
|
||||
|
||||
def mpc_expj(z, prec, rnd='f'):
|
||||
re, im = z
|
||||
if im == fzero:
|
||||
return mpf_cos_sin(re, prec, rnd)
|
||||
if re == fzero:
|
||||
return mpf_exp(mpf_neg(im), prec, rnd), fzero
|
||||
ey = mpf_exp(mpf_neg(im), prec+10)
|
||||
c, s = mpf_cos_sin(re, prec+10)
|
||||
re = mpf_mul(ey, c, prec, rnd)
|
||||
im = mpf_mul(ey, s, prec, rnd)
|
||||
return re, im
|
||||
|
||||
def mpf_expjpi(x, prec, rnd='f'):
|
||||
raise ComplexResult
|
||||
|
||||
def mpc_expjpi(z, prec, rnd='f'):
|
||||
re, im = z
|
||||
if im == fzero:
|
||||
return mpf_cos_sin_pi(re, prec, rnd)
|
||||
sign, man, exp, bc = im
|
||||
wp = prec+10
|
||||
if man:
|
||||
wp += max(0, exp+bc)
|
||||
im = mpf_neg(mpf_mul(mpf_pi(wp), im, wp))
|
||||
if re == fzero:
|
||||
return mpf_exp(im, prec, rnd), fzero
|
||||
ey = mpf_exp(im, prec+10)
|
||||
c, s = mpf_cos_sin_pi(re, prec+10)
|
||||
re = mpf_mul(ey, c, prec, rnd)
|
||||
im = mpf_mul(ey, s, prec, rnd)
|
||||
return re, im
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,348 +0,0 @@
|
|||
"""
|
||||
Computational functions for interval arithmetic.
|
||||
|
||||
"""
|
||||
|
||||
from libmpf import (
|
||||
ComplexResult,
|
||||
round_down, round_up, round_floor, round_ceiling, round_nearest,
|
||||
prec_to_dps, repr_dps,
|
||||
fnan, finf, fninf, fzero, fhalf, fone, fnone,
|
||||
mpf_sign, mpf_lt, mpf_le, mpf_gt, mpf_ge, mpf_eq, mpf_cmp,
|
||||
mpf_floor, from_int, to_int, to_str,
|
||||
mpf_abs, mpf_neg, mpf_pos, mpf_add, mpf_sub, mpf_mul,
|
||||
mpf_div, mpf_shift, mpf_pow_int)
|
||||
|
||||
from libelefun import (
|
||||
mpf_log, mpf_exp, mpf_sqrt, reduce_angle, calc_cos_sin
|
||||
)
|
||||
|
||||
def mpi_str(s, prec):
|
||||
sa, sb = s
|
||||
dps = prec_to_dps(prec) + 5
|
||||
return "[%s, %s]" % (to_str(sa, dps), to_str(sb, dps))
|
||||
|
||||
#dps = prec_to_dps(prec)
|
||||
#m = mpi_mid(s, prec)
|
||||
#d = mpf_shift(mpi_delta(s, 20), -1)
|
||||
#return "%s +/- %s" % (to_str(m, dps), to_str(d, 3))
|
||||
|
||||
def mpi_add(s, t, prec):
|
||||
sa, sb = s
|
||||
ta, tb = t
|
||||
a = mpf_add(sa, ta, prec, round_floor)
|
||||
b = mpf_add(sb, tb, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = finf
|
||||
return a, b
|
||||
|
||||
def mpi_sub(s, t, prec):
|
||||
sa, sb = s
|
||||
ta, tb = t
|
||||
a = mpf_sub(sa, tb, prec, round_floor)
|
||||
b = mpf_sub(sb, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = finf
|
||||
return a, b
|
||||
|
||||
def mpi_delta(s, prec):
|
||||
sa, sb = s
|
||||
return mpf_sub(sb, sa, prec, round_up)
|
||||
|
||||
def mpi_mid(s, prec):
|
||||
sa, sb = s
|
||||
return mpf_shift(mpf_add(sa, sb, prec, round_nearest), -1)
|
||||
|
||||
def mpi_pos(s, prec):
|
||||
sa, sb = s
|
||||
a = mpf_pos(sa, prec, round_floor)
|
||||
b = mpf_pos(sb, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_neg(s, prec=None):
|
||||
sa, sb = s
|
||||
a = mpf_neg(sb, prec, round_floor)
|
||||
b = mpf_neg(sa, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_abs(s, prec):
|
||||
sa, sb = s
|
||||
sas = mpf_sign(sa)
|
||||
sbs = mpf_sign(sb)
|
||||
# Both points nonnegative?
|
||||
if sas >= 0:
|
||||
a = mpf_pos(sa, prec, round_floor)
|
||||
b = mpf_pos(sb, prec, round_ceiling)
|
||||
# Upper point nonnegative?
|
||||
elif sbs >= 0:
|
||||
a = fzero
|
||||
negsa = mpf_neg(sa)
|
||||
if mpf_lt(negsa, sb):
|
||||
b = mpf_pos(sb, prec, round_ceiling)
|
||||
else:
|
||||
b = mpf_pos(negsa, prec, round_ceiling)
|
||||
# Both negative?
|
||||
else:
|
||||
a = mpf_neg(sb, prec, round_floor)
|
||||
b = mpf_neg(sa, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_mul(s, t, prec):
|
||||
sa, sb = s
|
||||
ta, tb = t
|
||||
sas = mpf_sign(sa)
|
||||
sbs = mpf_sign(sb)
|
||||
tas = mpf_sign(ta)
|
||||
tbs = mpf_sign(tb)
|
||||
if sas == sbs == 0:
|
||||
# Should maybe be undefined
|
||||
if ta == fninf or tb == finf:
|
||||
return fninf, finf
|
||||
return fzero, fzero
|
||||
if tas == tbs == 0:
|
||||
# Should maybe be undefined
|
||||
if sa == fninf or sb == finf:
|
||||
return fninf, finf
|
||||
return fzero, fzero
|
||||
if sas >= 0:
|
||||
# positive * positive
|
||||
if tas >= 0:
|
||||
a = mpf_mul(sa, ta, prec, round_floor)
|
||||
b = mpf_mul(sb, tb, prec, round_ceiling)
|
||||
if a == fnan: a = fzero
|
||||
if b == fnan: b = finf
|
||||
# positive * negative
|
||||
elif tbs <= 0:
|
||||
a = mpf_mul(sb, ta, prec, round_floor)
|
||||
b = mpf_mul(sa, tb, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = fzero
|
||||
# positive * both signs
|
||||
else:
|
||||
a = mpf_mul(sb, ta, prec, round_floor)
|
||||
b = mpf_mul(sb, tb, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = finf
|
||||
elif sbs <= 0:
|
||||
# negative * positive
|
||||
if tas >= 0:
|
||||
a = mpf_mul(sa, tb, prec, round_floor)
|
||||
b = mpf_mul(sb, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = fzero
|
||||
# negative * negative
|
||||
elif tbs <= 0:
|
||||
a = mpf_mul(sb, tb, prec, round_floor)
|
||||
b = mpf_mul(sa, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fzero
|
||||
if b == fnan: b = finf
|
||||
# negative * both signs
|
||||
else:
|
||||
a = mpf_mul(sa, tb, prec, round_floor)
|
||||
b = mpf_mul(sa, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = finf
|
||||
else:
|
||||
# General case: perform all cross-multiplications and compare
|
||||
# Since the multiplications can be done exactly, we need only
|
||||
# do 4 (instead of 8: two for each rounding mode)
|
||||
cases = [mpf_mul(sa, ta), mpf_mul(sa, tb), mpf_mul(sb, ta), mpf_mul(sb, tb)]
|
||||
if fnan in cases:
|
||||
a, b = (fninf, finf)
|
||||
else:
|
||||
cases = sorted(cases, cmp=mpf_cmp)
|
||||
a = mpf_pos(cases[0], prec, round_floor)
|
||||
b = mpf_pos(cases[-1], prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_div(s, t, prec):
|
||||
sa, sb = s
|
||||
ta, tb = t
|
||||
sas = mpf_sign(sa)
|
||||
sbs = mpf_sign(sb)
|
||||
tas = mpf_sign(ta)
|
||||
tbs = mpf_sign(tb)
|
||||
# 0 / X
|
||||
if sas == sbs == 0:
|
||||
# 0 / <interval containing 0>
|
||||
if (tas < 0 and tbs > 0) or (tas == 0 or tbs == 0):
|
||||
return fninf, finf
|
||||
return fzero, fzero
|
||||
# Denominator contains both negative and positive numbers;
|
||||
# this should properly be a multi-interval, but the closest
|
||||
# match is the entire (extended) real line
|
||||
if tas < 0 and tbs > 0:
|
||||
return fninf, finf
|
||||
# Assume denominator to be nonnegative
|
||||
if tas < 0:
|
||||
return mpi_div(mpi_neg(s), mpi_neg(t), prec)
|
||||
# Division by zero
|
||||
# XXX: make sure all results make sense
|
||||
if tas == 0:
|
||||
# Numerator contains both signs?
|
||||
if sas < 0 and sbs > 0:
|
||||
return fninf, finf
|
||||
if tas == tbs:
|
||||
return fninf, finf
|
||||
# Numerator positive?
|
||||
if sas >= 0:
|
||||
a = mpf_div(sa, tb, prec, round_floor)
|
||||
b = finf
|
||||
if sbs <= 0:
|
||||
a = fninf
|
||||
b = mpf_div(sb, tb, prec, round_ceiling)
|
||||
# Division with positive denominator
|
||||
# We still have to handle nans resulting from inf/0 or inf/inf
|
||||
else:
|
||||
# Nonnegative numerator
|
||||
if sas >= 0:
|
||||
a = mpf_div(sa, tb, prec, round_floor)
|
||||
b = mpf_div(sb, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fzero
|
||||
if b == fnan: b = finf
|
||||
# Nonpositive numerator
|
||||
elif sbs <= 0:
|
||||
a = mpf_div(sa, ta, prec, round_floor)
|
||||
b = mpf_div(sb, tb, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = fzero
|
||||
# Numerator contains both signs?
|
||||
else:
|
||||
a = mpf_div(sa, ta, prec, round_floor)
|
||||
b = mpf_div(sb, ta, prec, round_ceiling)
|
||||
if a == fnan: a = fninf
|
||||
if b == fnan: b = finf
|
||||
return a, b
|
||||
|
||||
def mpi_exp(s, prec):
|
||||
sa, sb = s
|
||||
# exp is monotonous
|
||||
a = mpf_exp(sa, prec, round_floor)
|
||||
b = mpf_exp(sb, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_log(s, prec):
|
||||
sa, sb = s
|
||||
# log is monotonous
|
||||
a = mpf_log(sa, prec, round_floor)
|
||||
b = mpf_log(sb, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_sqrt(s, prec):
|
||||
sa, sb = s
|
||||
# sqrt is monotonous
|
||||
a = mpf_sqrt(sa, prec, round_floor)
|
||||
b = mpf_sqrt(sb, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_pow_int(s, n, prec):
|
||||
sa, sb = s
|
||||
if n < 0:
|
||||
return mpi_div((fone, fone), mpi_pow_int(s, -n, prec+20), prec)
|
||||
if n == 0:
|
||||
return (fone, fone)
|
||||
if n == 1:
|
||||
return s
|
||||
# Odd -- signs are preserved
|
||||
if n & 1:
|
||||
a = mpf_pow_int(sa, n, prec, round_floor)
|
||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
||||
# Even -- important to ensure positivity
|
||||
else:
|
||||
sas = mpf_sign(sa)
|
||||
sbs = mpf_sign(sb)
|
||||
# Nonnegative?
|
||||
if sas >= 0:
|
||||
a = mpf_pow_int(sa, n, prec, round_floor)
|
||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
||||
# Nonpositive?
|
||||
elif sbs <= 0:
|
||||
a = mpf_pow_int(sb, n, prec, round_floor)
|
||||
b = mpf_pow_int(sa, n, prec, round_ceiling)
|
||||
# Mixed signs?
|
||||
else:
|
||||
a = fzero
|
||||
# max(-a,b)**n
|
||||
sa = mpf_neg(sa)
|
||||
if mpf_ge(sa, sb):
|
||||
b = mpf_pow_int(sa, n, prec, round_ceiling)
|
||||
else:
|
||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
||||
return a, b
|
||||
|
||||
def mpi_pow(s, t, prec):
|
||||
ta, tb = t
|
||||
if ta == tb and ta not in (finf, fninf):
|
||||
if ta == from_int(to_int(ta)):
|
||||
return mpi_pow_int(s, to_int(ta), prec)
|
||||
if ta == fhalf:
|
||||
return mpi_sqrt(s, prec)
|
||||
u = mpi_log(s, prec + 20)
|
||||
v = mpi_mul(u, t, prec + 20)
|
||||
return mpi_exp(v, prec)
|
||||
|
||||
def MIN(x, y):
|
||||
if mpf_le(x, y):
|
||||
return x
|
||||
return y
|
||||
|
||||
def MAX(x, y):
|
||||
if mpf_ge(x, y):
|
||||
return x
|
||||
return y
|
||||
|
||||
def mpi_cos_sin(x, prec):
|
||||
a, b = x
|
||||
# Guaranteed to contain both -1 and 1
|
||||
if finf in (a, b) or fninf in (a, b):
|
||||
return (fnone, fone), (fnone, fone)
|
||||
y, yswaps, yn = reduce_angle(a, prec+20)
|
||||
z, zswaps, zn = reduce_angle(b, prec+20)
|
||||
# Guaranteed to contain both -1 and 1
|
||||
if zn - yn >= 4:
|
||||
return (fnone, fone), (fnone, fone)
|
||||
# Both points in the same quadrant -- cos and sin both strictly monotonous
|
||||
if yn == zn:
|
||||
m = yn % 4
|
||||
if m == 0:
|
||||
cb, sa = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
|
||||
ca, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
||||
if m == 1:
|
||||
cb, sb = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_ceiling)
|
||||
ca, sa = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
||||
if m == 2:
|
||||
ca, sb = calc_cos_sin(0, y, yswaps, prec, round_floor, round_ceiling)
|
||||
cb, sa = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_floor)
|
||||
if m == 3:
|
||||
ca, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
|
||||
cb, sb = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_ceiling)
|
||||
return (ca, cb), (sa, sb)
|
||||
# Intervals spanning multiple quadrants
|
||||
yn %= 4
|
||||
zn %= 4
|
||||
case = (yn, zn)
|
||||
if case == (0, 1):
|
||||
cb, sy = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
|
||||
ca, sz = calc_cos_sin(0, z, zswaps, prec, round_floor, round_floor)
|
||||
return (ca, cb), (MIN(sy, sz), fone)
|
||||
if case == (3, 0):
|
||||
cy, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
|
||||
cz, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
||||
return (MIN(cy, cz), fone), (sa, sb)
|
||||
|
||||
|
||||
raise NotImplementedError("cos/sin spanning multiple quadrants")
|
||||
|
||||
def mpi_cos(x, prec):
|
||||
return mpi_cos_sin(x, prec)[0]
|
||||
|
||||
def mpi_sin(x, prec):
|
||||
return mpi_cos_sin(x, prec)[1]
|
||||
|
||||
def mpi_tan(x, prec):
|
||||
cos, sin = mpi_cos_sin(x, prec+20)
|
||||
return mpi_div(sin, cos, prec)
|
||||
|
||||
def mpi_cot(x, prec):
|
||||
cos, sin = mpi_cos_sin(x, prec+20)
|
||||
return mpi_div(cos, sin, prec)
|
||||
|
|
@ -1,645 +0,0 @@
|
|||
"""
|
||||
This module complements the math and cmath builtin modules by providing
|
||||
fast machine precision versions of some additional functions (gamma, ...)
|
||||
and wrapping math/cmath functions so that they can be called with either
|
||||
real or complex arguments.
|
||||
"""
|
||||
|
||||
import operator
|
||||
import math
|
||||
import cmath
|
||||
|
||||
# Irrational (?) constants
|
||||
pi = 3.1415926535897932385
|
||||
e = 2.7182818284590452354
|
||||
sqrt2 = 1.4142135623730950488
|
||||
sqrt5 = 2.2360679774997896964
|
||||
phi = 1.6180339887498948482
|
||||
ln2 = 0.69314718055994530942
|
||||
ln10 = 2.302585092994045684
|
||||
euler = 0.57721566490153286061
|
||||
catalan = 0.91596559417721901505
|
||||
khinchin = 2.6854520010653064453
|
||||
apery = 1.2020569031595942854
|
||||
|
||||
logpi = 1.1447298858494001741
|
||||
|
||||
def _mathfun_real(f_real, f_complex):
|
||||
def f(x, **kwargs):
|
||||
if type(x) is float:
|
||||
return f_real(x)
|
||||
if type(x) is complex:
|
||||
return f_complex(x)
|
||||
try:
|
||||
x = float(x)
|
||||
return f_real(x)
|
||||
except (TypeError, ValueError):
|
||||
x = complex(x)
|
||||
return f_complex(x)
|
||||
f.__name__ = f_real.__name__
|
||||
return f
|
||||
|
||||
def _mathfun(f_real, f_complex):
|
||||
def f(x, **kwargs):
|
||||
if type(x) is complex:
|
||||
return f_complex(x)
|
||||
try:
|
||||
return f_real(float(x))
|
||||
except (TypeError, ValueError):
|
||||
return f_complex(complex(x))
|
||||
f.__name__ = f_real.__name__
|
||||
return f
|
||||
|
||||
def _mathfun_n(f_real, f_complex):
|
||||
def f(*args, **kwargs):
|
||||
try:
|
||||
return f_real(*(float(x) for x in args))
|
||||
except (TypeError, ValueError):
|
||||
return f_complex(*(complex(x) for x in args))
|
||||
f.__name__ = f_real.__name__
|
||||
return f
|
||||
|
||||
pow = _mathfun_n(operator.pow, lambda x, y: complex(x)**y)
|
||||
log = _mathfun_n(math.log, cmath.log)
|
||||
sqrt = _mathfun(math.sqrt, cmath.sqrt)
|
||||
exp = _mathfun_real(math.exp, cmath.exp)
|
||||
|
||||
cos = _mathfun_real(math.cos, cmath.cos)
|
||||
sin = _mathfun_real(math.sin, cmath.sin)
|
||||
tan = _mathfun_real(math.tan, cmath.tan)
|
||||
|
||||
acos = _mathfun(math.acos, cmath.acos)
|
||||
asin = _mathfun(math.asin, cmath.asin)
|
||||
atan = _mathfun_real(math.atan, cmath.atan)
|
||||
|
||||
cosh = _mathfun_real(math.cosh, cmath.cosh)
|
||||
sinh = _mathfun_real(math.sinh, cmath.sinh)
|
||||
tanh = _mathfun_real(math.tanh, cmath.tanh)
|
||||
|
||||
floor = _mathfun_real(math.floor,
|
||||
lambda z: complex(math.floor(z.real), math.floor(z.imag)))
|
||||
ceil = _mathfun_real(math.ceil,
|
||||
lambda z: complex(math.ceil(z.real), math.ceil(z.imag)))
|
||||
|
||||
|
||||
cos_sin = _mathfun_real(lambda x: (math.cos(x), math.sin(x)),
|
||||
lambda z: (cmath.cos(z), cmath.sin(z)))
|
||||
|
||||
cbrt = _mathfun(lambda x: x**(1./3), lambda z: z**(1./3))
|
||||
|
||||
def nthroot(x, n):
|
||||
r = 1./n
|
||||
try:
|
||||
return float(x) ** r
|
||||
except (ValueError, TypeError):
|
||||
return complex(x) ** r
|
||||
|
||||
def _sinpi_real(x):
|
||||
if x < 0:
|
||||
return -_sinpi_real(-x)
|
||||
n, r = divmod(x, 0.5)
|
||||
r *= pi
|
||||
n %= 4
|
||||
if n == 0: return math.sin(r)
|
||||
if n == 1: return math.cos(r)
|
||||
if n == 2: return -math.sin(r)
|
||||
if n == 3: return -math.cos(r)
|
||||
|
||||
def _cospi_real(x):
|
||||
if x < 0:
|
||||
x = -x
|
||||
n, r = divmod(x, 0.5)
|
||||
r *= pi
|
||||
n %= 4
|
||||
if n == 0: return math.cos(r)
|
||||
if n == 1: return -math.sin(r)
|
||||
if n == 2: return -math.cos(r)
|
||||
if n == 3: return math.sin(r)
|
||||
|
||||
def _sinpi_complex(z):
|
||||
if z.real < 0:
|
||||
return -_sinpi_complex(-z)
|
||||
n, r = divmod(z.real, 0.5)
|
||||
z = pi*complex(r, z.imag)
|
||||
n %= 4
|
||||
if n == 0: return cmath.sin(z)
|
||||
if n == 1: return cmath.cos(z)
|
||||
if n == 2: return -cmath.sin(z)
|
||||
if n == 3: return -cmath.cos(z)
|
||||
|
||||
def _cospi_complex(z):
|
||||
if z.real < 0:
|
||||
z = -z
|
||||
n, r = divmod(z.real, 0.5)
|
||||
z = pi*complex(r, z.imag)
|
||||
n %= 4
|
||||
if n == 0: return cmath.cos(z)
|
||||
if n == 1: return -cmath.sin(z)
|
||||
if n == 2: return -cmath.cos(z)
|
||||
if n == 3: return cmath.sin(z)
|
||||
|
||||
cospi = _mathfun_real(_cospi_real, _cospi_complex)
|
||||
sinpi = _mathfun_real(_sinpi_real, _sinpi_complex)
|
||||
|
||||
def tanpi(x):
|
||||
try:
|
||||
return sinpi(x) / cospi(x)
|
||||
except OverflowError:
|
||||
if complex(x).imag > 10:
|
||||
return 1j
|
||||
if complex(x).imag < 10:
|
||||
return -1j
|
||||
raise
|
||||
|
||||
def cotpi(x):
|
||||
try:
|
||||
return cospi(x) / sinpi(x)
|
||||
except OverflowError:
|
||||
if complex(x).imag > 10:
|
||||
return -1j
|
||||
if complex(x).imag < 10:
|
||||
return 1j
|
||||
raise
|
||||
|
||||
INF = 1e300*1e300
|
||||
NINF = -INF
|
||||
NAN = INF-INF
|
||||
EPS = 2.2204460492503131e-16
|
||||
|
||||
_exact_gamma = (INF, 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0,
|
||||
362880.0, 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
|
||||
1307674368000.0, 20922789888000.0, 355687428096000.0, 6402373705728000.0,
|
||||
121645100408832000.0, 2432902008176640000.0)
|
||||
|
||||
_max_exact_gamma = len(_exact_gamma)-1
|
||||
|
||||
# Lanczos coefficients used by the GNU Scientific Library
|
||||
_lanczos_g = 7
|
||||
_lanczos_p = (0.99999999999980993, 676.5203681218851, -1259.1392167224028,
|
||||
771.32342877765313, -176.61502916214059, 12.507343278686905,
|
||||
-0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7)
|
||||
|
||||
def _gamma_real(x):
|
||||
_intx = int(x)
|
||||
if _intx == x:
|
||||
if _intx <= 0:
|
||||
#return (-1)**_intx * INF
|
||||
raise ZeroDivisionError("gamma function pole")
|
||||
if _intx <= _max_exact_gamma:
|
||||
return _exact_gamma[_intx]
|
||||
if x < 0.5:
|
||||
# TODO: sinpi
|
||||
return pi / (_sinpi_real(x)*_gamma_real(1-x))
|
||||
else:
|
||||
x -= 1.0
|
||||
r = _lanczos_p[0]
|
||||
for i in range(1, _lanczos_g+2):
|
||||
r += _lanczos_p[i]/(x+i)
|
||||
t = x + _lanczos_g + 0.5
|
||||
return 2.506628274631000502417 * t**(x+0.5) * math.exp(-t) * r
|
||||
|
||||
def _gamma_complex(x):
|
||||
if not x.imag:
|
||||
return complex(_gamma_real(x.real))
|
||||
if x.real < 0.5:
|
||||
# TODO: sinpi
|
||||
return pi / (_sinpi_complex(x)*_gamma_complex(1-x))
|
||||
else:
|
||||
x -= 1.0
|
||||
r = _lanczos_p[0]
|
||||
for i in range(1, _lanczos_g+2):
|
||||
r += _lanczos_p[i]/(x+i)
|
||||
t = x + _lanczos_g + 0.5
|
||||
return 2.506628274631000502417 * t**(x+0.5) * cmath.exp(-t) * r
|
||||
|
||||
gamma = _mathfun_real(_gamma_real, _gamma_complex)
|
||||
|
||||
def factorial(x):
|
||||
return gamma(x+1.0)
|
||||
|
||||
def arg(x):
|
||||
if type(x) is float:
|
||||
return math.atan2(0.0,x)
|
||||
return math.atan2(x.imag,x.real)
|
||||
|
||||
# XXX: broken for negatives
|
||||
def loggamma(x):
|
||||
if type(x) not in (float, complex):
|
||||
try:
|
||||
x = float(x)
|
||||
except (ValueError, TypeError):
|
||||
x = complex(x)
|
||||
try:
|
||||
xreal = x.real
|
||||
ximag = x.imag
|
||||
except AttributeError: # py2.5
|
||||
xreal = x
|
||||
ximag = 0.0
|
||||
# Reflection formula
|
||||
# http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0003/
|
||||
if xreal < 0.0:
|
||||
if abs(x) < 0.5:
|
||||
v = log(gamma(x))
|
||||
if ximag == 0:
|
||||
v = v.conjugate()
|
||||
return v
|
||||
z = 1-x
|
||||
try:
|
||||
re = z.real
|
||||
im = z.imag
|
||||
except AttributeError: # py2.5
|
||||
re = z
|
||||
im = 0.0
|
||||
refloor = floor(re)
|
||||
imsign = cmp(im, 0)
|
||||
return (-pi*1j)*abs(refloor)*(1-abs(imsign)) + logpi - \
|
||||
log(sinpi(z-refloor)) - loggamma(z) + 1j*pi*refloor*imsign
|
||||
if x == 1.0 or x == 2.0:
|
||||
return x*0
|
||||
p = 0.
|
||||
while abs(x) < 11:
|
||||
p -= log(x)
|
||||
x += 1.0
|
||||
s = 0.918938533204672742 + (x-0.5)*log(x) - x
|
||||
r = 1./x
|
||||
r2 = r*r
|
||||
s += 0.083333333333333333333*r; r *= r2
|
||||
s += -0.0027777777777777777778*r; r *= r2
|
||||
s += 0.00079365079365079365079*r; r *= r2
|
||||
s += -0.0005952380952380952381*r; r *= r2
|
||||
s += 0.00084175084175084175084*r; r *= r2
|
||||
s += -0.0019175269175269175269*r; r *= r2
|
||||
s += 0.0064102564102564102564*r; r *= r2
|
||||
s += -0.02955065359477124183*r
|
||||
return s + p
|
||||
|
||||
_psi_coeff = [
|
||||
0.083333333333333333333,
|
||||
-0.0083333333333333333333,
|
||||
0.003968253968253968254,
|
||||
-0.0041666666666666666667,
|
||||
0.0075757575757575757576,
|
||||
-0.021092796092796092796,
|
||||
0.083333333333333333333,
|
||||
-0.44325980392156862745,
|
||||
3.0539543302701197438,
|
||||
-26.456212121212121212]
|
||||
|
||||
def _digamma_real(x):
|
||||
_intx = int(x)
|
||||
if _intx == x:
|
||||
if _intx <= 0:
|
||||
raise ZeroDivisionError("polygamma pole")
|
||||
if x < 0.5:
|
||||
x = 1.0-x
|
||||
s = pi*cotpi(x)
|
||||
else:
|
||||
s = 0.0
|
||||
while x < 10.0:
|
||||
s -= 1.0/x
|
||||
x += 1.0
|
||||
x2 = x**-2
|
||||
t = x2
|
||||
for c in _psi_coeff:
|
||||
s -= c*t
|
||||
if t < 1e-20:
|
||||
break
|
||||
t *= x2
|
||||
return s + math.log(x) - 0.5/x
|
||||
|
||||
def _digamma_complex(x):
|
||||
if not x.imag:
|
||||
return complex(_digamma_real(x.real))
|
||||
if x.real < 0.5:
|
||||
x = 1.0-x
|
||||
s = pi*cotpi(x)
|
||||
else:
|
||||
s = 0.0
|
||||
while abs(x) < 10.0:
|
||||
s -= 1.0/x
|
||||
x += 1.0
|
||||
x2 = x**-2
|
||||
t = x2
|
||||
for c in _psi_coeff:
|
||||
s -= c*t
|
||||
if abs(t) < 1e-20:
|
||||
break
|
||||
t *= x2
|
||||
return s + cmath.log(x) - 0.5/x
|
||||
|
||||
digamma = _mathfun_real(_digamma_real, _digamma_complex)
|
||||
|
||||
# TODO: could implement complex erf and erfc here. Need
|
||||
# to find an accurate method (avoiding cancellation)
|
||||
# for approx. 1 < abs(x) < 9.
|
||||
|
||||
_erfc_coeff_P = [
|
||||
1.0000000161203922312,
|
||||
2.1275306946297962644,
|
||||
2.2280433377390253297,
|
||||
1.4695509105618423961,
|
||||
0.66275911699770787537,
|
||||
0.20924776504163751585,
|
||||
0.045459713768411264339,
|
||||
0.0063065951710717791934,
|
||||
0.00044560259661560421715][::-1]
|
||||
|
||||
_erfc_coeff_Q = [
|
||||
1.0000000000000000000,
|
||||
3.2559100272784894318,
|
||||
4.9019435608903239131,
|
||||
4.4971472894498014205,
|
||||
2.7845640601891186528,
|
||||
1.2146026030046904138,
|
||||
0.37647108453729465912,
|
||||
0.080970149639040548613,
|
||||
0.011178148899483545902,
|
||||
0.00078981003831980423513][::-1]
|
||||
|
||||
def _polyval(coeffs, x):
|
||||
p = coeffs[0]
|
||||
for c in coeffs[1:]:
|
||||
p = c + x*p
|
||||
return p
|
||||
|
||||
def _erf_taylor(x):
|
||||
# Taylor series assuming 0 <= x <= 1
|
||||
x2 = x*x
|
||||
s = t = x
|
||||
n = 1
|
||||
while abs(t) > 1e-17:
|
||||
t *= x2/n
|
||||
s -= t/(n+n+1)
|
||||
n += 1
|
||||
t *= x2/n
|
||||
s += t/(n+n+1)
|
||||
n += 1
|
||||
return 1.1283791670955125739*s
|
||||
|
||||
def _erfc_mid(x):
|
||||
# Rational approximation assuming 0 <= x <= 9
|
||||
return exp(-x*x)*_polyval(_erfc_coeff_P,x)/_polyval(_erfc_coeff_Q,x)
|
||||
|
||||
def _erfc_asymp(x):
|
||||
# Asymptotic expansion assuming x >= 9
|
||||
x2 = x*x
|
||||
v = exp(-x2)/x*0.56418958354775628695
|
||||
r = t = 0.5 / x2
|
||||
s = 1.0
|
||||
for n in range(1,22,4):
|
||||
s -= t
|
||||
t *= r * (n+2)
|
||||
s += t
|
||||
t *= r * (n+4)
|
||||
if abs(t) < 1e-17:
|
||||
break
|
||||
return s * v
|
||||
|
||||
def erf(x):
|
||||
"""
|
||||
erf of a real number.
|
||||
"""
|
||||
x = float(x)
|
||||
if x != x:
|
||||
return x
|
||||
if x < 0.0:
|
||||
return -erf(-x)
|
||||
if x >= 1.0:
|
||||
if x >= 6.0:
|
||||
return 1.0
|
||||
return 1.0 - _erfc_mid(x)
|
||||
return _erf_taylor(x)
|
||||
|
||||
def erfc(x):
|
||||
"""
|
||||
erfc of a real number.
|
||||
"""
|
||||
x = float(x)
|
||||
if x != x:
|
||||
return x
|
||||
if x < 0.0:
|
||||
if x < -6.0:
|
||||
return 2.0
|
||||
return 2.0-erfc(-x)
|
||||
if x > 9.0:
|
||||
return _erfc_asymp(x)
|
||||
if x >= 1.0:
|
||||
return _erfc_mid(x)
|
||||
return 1.0 - _erf_taylor(x)
|
||||
|
||||
gauss42 = [\
|
||||
(0.99839961899006235, 0.0041059986046490839),
|
||||
(-0.99839961899006235, 0.0041059986046490839),
|
||||
(0.9915772883408609, 0.009536220301748501),
|
||||
(-0.9915772883408609,0.009536220301748501),
|
||||
(0.97934250806374812, 0.014922443697357493),
|
||||
(-0.97934250806374812, 0.014922443697357493),
|
||||
(0.96175936533820439,0.020227869569052644),
|
||||
(-0.96175936533820439, 0.020227869569052644),
|
||||
(0.93892355735498811, 0.025422959526113047),
|
||||
(-0.93892355735498811,0.025422959526113047),
|
||||
(0.91095972490412735, 0.030479240699603467),
|
||||
(-0.91095972490412735, 0.030479240699603467),
|
||||
(0.87802056981217269,0.03536907109759211),
|
||||
(-0.87802056981217269, 0.03536907109759211),
|
||||
(0.8402859832618168, 0.040065735180692258),
|
||||
(-0.8402859832618168,0.040065735180692258),
|
||||
(0.7979620532554873, 0.044543577771965874),
|
||||
(-0.7979620532554873, 0.044543577771965874),
|
||||
(0.75127993568948048,0.048778140792803244),
|
||||
(-0.75127993568948048, 0.048778140792803244),
|
||||
(0.70049459055617114, 0.052746295699174064),
|
||||
(-0.70049459055617114,0.052746295699174064),
|
||||
(0.64588338886924779, 0.056426369358018376),
|
||||
(-0.64588338886924779, 0.056426369358018376),
|
||||
(0.58774459748510932, 0.059798262227586649),
|
||||
(-0.58774459748510932, 0.059798262227586649),
|
||||
(0.5263957499311922, 0.062843558045002565),
|
||||
(-0.5263957499311922, 0.062843558045002565),
|
||||
(0.46217191207042191, 0.065545624364908975),
|
||||
(-0.46217191207042191, 0.065545624364908975),
|
||||
(0.39542385204297503, 0.067889703376521934),
|
||||
(-0.39542385204297503, 0.067889703376521934),
|
||||
(0.32651612446541151, 0.069862992492594159),
|
||||
(-0.32651612446541151, 0.069862992492594159),
|
||||
(0.25582507934287907, 0.071454714265170971),
|
||||
(-0.25582507934287907, 0.071454714265170971),
|
||||
(0.18373680656485453, 0.072656175243804091),
|
||||
(-0.18373680656485453, 0.072656175243804091),
|
||||
(0.11064502720851986, 0.073460813453467527),
|
||||
(-0.11064502720851986, 0.073460813453467527),
|
||||
(0.036948943165351772, 0.073864234232172879),
|
||||
(-0.036948943165351772, 0.073864234232172879)]
|
||||
|
||||
EI_ASYMP_CONVERGENCE_RADIUS = 40.0
|
||||
|
||||
def ei_asymp(z, _e1=False):
|
||||
r = 1./z
|
||||
s = t = 1.0
|
||||
k = 1
|
||||
while 1:
|
||||
t *= k*r
|
||||
s += t
|
||||
if abs(t) < 1e-16:
|
||||
break
|
||||
k += 1
|
||||
v = s*exp(z)/z
|
||||
if _e1:
|
||||
if type(z) is complex:
|
||||
zreal = z.real
|
||||
zimag = z.imag
|
||||
else:
|
||||
zreal = z
|
||||
zimag = 0.0
|
||||
if zimag == 0.0 and zreal > 0.0:
|
||||
v += pi*1j
|
||||
else:
|
||||
if type(z) is complex:
|
||||
if z.imag > 0:
|
||||
v += pi*1j
|
||||
if z.imag < 0:
|
||||
v -= pi*1j
|
||||
return v
|
||||
|
||||
def ei_taylor(z, _e1=False):
|
||||
s = t = z
|
||||
k = 2
|
||||
while 1:
|
||||
t = t*z/k
|
||||
term = t/k
|
||||
if abs(term) < 1e-17:
|
||||
break
|
||||
s += term
|
||||
k += 1
|
||||
s += euler
|
||||
if _e1:
|
||||
s += log(-z)
|
||||
else:
|
||||
if type(z) is float or z.imag == 0.0:
|
||||
s += math.log(abs(z))
|
||||
else:
|
||||
s += cmath.log(z)
|
||||
return s
|
||||
|
||||
def ei(z, _e1=False):
|
||||
typez = type(z)
|
||||
if typez not in (float, complex):
|
||||
try:
|
||||
z = float(z)
|
||||
typez = float
|
||||
except (TypeError, ValueError):
|
||||
z = complex(z)
|
||||
typez = complex
|
||||
if not z:
|
||||
return -INF
|
||||
absz = abs(z)
|
||||
if absz > EI_ASYMP_CONVERGENCE_RADIUS:
|
||||
return ei_asymp(z, _e1)
|
||||
elif absz <= 2.0 or (typez is float and z > 0.0):
|
||||
return ei_taylor(z, _e1)
|
||||
# Integrate, starting from whichever is smaller of a Taylor
|
||||
# series value or an asymptotic series value
|
||||
if typez is complex and z.real > 0.0:
|
||||
zref = z / absz
|
||||
ref = ei_taylor(zref, _e1)
|
||||
else:
|
||||
zref = EI_ASYMP_CONVERGENCE_RADIUS * z / absz
|
||||
ref = ei_asymp(zref, _e1)
|
||||
C = (zref-z)*0.5
|
||||
D = (zref+z)*0.5
|
||||
s = 0.0
|
||||
if type(z) is complex:
|
||||
_exp = cmath.exp
|
||||
else:
|
||||
_exp = math.exp
|
||||
for x,w in gauss42:
|
||||
t = C*x+D
|
||||
s += w*_exp(t)/t
|
||||
ref -= C*s
|
||||
return ref
|
||||
|
||||
def e1(z):
|
||||
# hack to get consistent signs if the imaginary part if 0
|
||||
# and signed
|
||||
typez = type(z)
|
||||
if type(z) not in (float, complex):
|
||||
try:
|
||||
z = float(z)
|
||||
typez = float
|
||||
except (TypeError, ValueError):
|
||||
z = complex(z)
|
||||
typez = complex
|
||||
if typez is complex and not z.imag:
|
||||
z = complex(z.real, 0.0)
|
||||
# end hack
|
||||
return -ei(-z, _e1=True)
|
||||
|
||||
_zeta_int = [\
|
||||
-0.5,
|
||||
0.0,
|
||||
1.6449340668482264365,1.2020569031595942854,1.0823232337111381915,
|
||||
1.0369277551433699263,1.0173430619844491397,1.0083492773819228268,
|
||||
1.0040773561979443394,1.0020083928260822144,1.0009945751278180853,
|
||||
1.0004941886041194646,1.0002460865533080483,1.0001227133475784891,
|
||||
1.0000612481350587048,1.0000305882363070205,1.0000152822594086519,
|
||||
1.0000076371976378998,1.0000038172932649998,1.0000019082127165539,
|
||||
1.0000009539620338728,1.0000004769329867878,1.0000002384505027277,
|
||||
1.0000001192199259653,1.0000000596081890513,1.0000000298035035147,
|
||||
1.0000000149015548284]
|
||||
|
||||
_zeta_P = [-3.50000000087575873, -0.701274355654678147,
|
||||
-0.0672313458590012612, -0.00398731457954257841,
|
||||
-0.000160948723019303141, -4.67633010038383371e-6,
|
||||
-1.02078104417700585e-7, -1.68030037095896287e-9,
|
||||
-1.85231868742346722e-11][::-1]
|
||||
|
||||
_zeta_Q = [1.00000000000000000, -0.936552848762465319,
|
||||
-0.0588835413263763741, -0.00441498861482948666,
|
||||
-0.000143416758067432622, -5.10691659585090782e-6,
|
||||
-9.58813053268913799e-8, -1.72963791443181972e-9,
|
||||
-1.83527919681474132e-11][::-1]
|
||||
|
||||
_zeta_1 = [3.03768838606128127e-10, -1.21924525236601262e-8,
|
||||
2.01201845887608893e-7, -1.53917240683468381e-6,
|
||||
-5.09890411005967954e-7, 0.000122464707271619326,
|
||||
-0.000905721539353130232, -0.00239315326074843037,
|
||||
0.084239750013159168, 0.418938517907442414, 0.500000001921884009]
|
||||
|
||||
_zeta_0 = [-3.46092485016748794e-10, -6.42610089468292485e-9,
|
||||
1.76409071536679773e-7, -1.47141263991560698e-6, -6.38880222546167613e-7,
|
||||
0.000122641099800668209, -0.000905894913516772796, -0.00239303348507992713,
|
||||
0.0842396947501199816, 0.418938533204660256, 0.500000000000000052]
|
||||
|
||||
def zeta(s):
|
||||
"""
|
||||
Riemann zeta function, real argument
|
||||
"""
|
||||
if not isinstance(s, (float, int)):
|
||||
try:
|
||||
s = float(s)
|
||||
except (ValueError, TypeError):
|
||||
try:
|
||||
s = complex(s)
|
||||
if not s.imag:
|
||||
return complex(zeta(s.real))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
if s == 1:
|
||||
raise ValueError("zeta(1) pole")
|
||||
if s >= 27:
|
||||
return 1.0 + 2.0**(-s) + 3.0**(-s)
|
||||
n = int(s)
|
||||
if n == s:
|
||||
if n >= 0:
|
||||
return _zeta_int[n]
|
||||
if not (n % 2):
|
||||
return 0.0
|
||||
if s <= 0.0:
|
||||
return 2.**s*pi**(s-1)*_sinpi_real(0.5*s)*_gamma_real(1-s)*zeta(1-s)
|
||||
if s <= 2.0:
|
||||
if s <= 1.0:
|
||||
return _polyval(_zeta_0,s)/(s-1)
|
||||
return _polyval(_zeta_1,s)/(s-1)
|
||||
z = _polyval(_zeta_P,s) / _polyval(_zeta_Q,s)
|
||||
return 1.0 + 2.0**(-s) + 3.0**(-s) + 4.0**(-s)*z
|
||||
|
|
@ -1,522 +0,0 @@
|
|||
# TODO: should use diagonalization-based algorithms
|
||||
|
||||
class MatrixCalculusMethods:
|
||||
|
||||
def _exp_pade(ctx, a):
|
||||
"""
|
||||
Exponential of a matrix using Pade approximants.
|
||||
|
||||
See G. H. Golub, C. F. van Loan 'Matrix Computations',
|
||||
third Ed., page 572
|
||||
|
||||
TODO:
|
||||
- find a good estimate for q
|
||||
- reduce the number of matrix multiplications to improve
|
||||
performance
|
||||
"""
|
||||
def eps_pade(p):
|
||||
return ctx.mpf(2)**(3-2*p) * \
|
||||
ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
|
||||
q = 4
|
||||
extraq = 8
|
||||
while 1:
|
||||
if eps_pade(q) < ctx.eps:
|
||||
break
|
||||
q += 1
|
||||
q += extraq
|
||||
j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
|
||||
extra = q
|
||||
prec = ctx.prec
|
||||
ctx.dps += extra + 3
|
||||
try:
|
||||
a = a/2**j
|
||||
na = a.rows
|
||||
den = ctx.eye(na)
|
||||
num = ctx.eye(na)
|
||||
x = ctx.eye(na)
|
||||
c = ctx.mpf(1)
|
||||
for k in range(1, q+1):
|
||||
c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
|
||||
x = a*x
|
||||
cx = c*x
|
||||
num += cx
|
||||
den += (-1)**k * cx
|
||||
f = ctx.lu_solve_mat(den, num)
|
||||
for k in range(j):
|
||||
f = f*f
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return f*1
|
||||
|
||||
def expm(ctx, A, method='taylor'):
|
||||
r"""
|
||||
Computes the matrix exponential of a square matrix `A`, which is defined
|
||||
by the power series
|
||||
|
||||
.. math ::
|
||||
|
||||
\exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
|
||||
|
||||
With method='taylor', the matrix exponential is computed
|
||||
using the Taylor series. With method='pade', Pade approximants
|
||||
are used instead.
|
||||
|
||||
**Examples**
|
||||
|
||||
Basic examples::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> expm(zeros(3))
|
||||
[1.0 0.0 0.0]
|
||||
[0.0 1.0 0.0]
|
||||
[0.0 0.0 1.0]
|
||||
>>> expm(eye(3))
|
||||
[2.71828182845905 0.0 0.0]
|
||||
[ 0.0 2.71828182845905 0.0]
|
||||
[ 0.0 0.0 2.71828182845905]
|
||||
>>> expm([[1,1,0],[1,0,1],[0,1,0]])
|
||||
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
||||
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
||||
[0.841130841230196 1.42699786729125 1.6000162976327]
|
||||
>>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
|
||||
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
||||
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
||||
[0.841130841230196 1.42699786729125 1.6000162976327]
|
||||
>>> expm([[1+j, 0], [1+j,1]])
|
||||
[(1.46869393991589 + 2.28735528717884j) 0.0]
|
||||
[ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
|
||||
|
||||
Matrices with large entries are allowed::
|
||||
|
||||
>>> expm(matrix([[1,2],[2,3]])**25)
|
||||
[5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
|
||||
[9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
|
||||
|
||||
The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
|
||||
noncommuting matrices::
|
||||
|
||||
>>> A = hilbert(3)
|
||||
>>> B = A + eye(3)
|
||||
>>> chop(mnorm(A*B - B*A))
|
||||
0.0
|
||||
>>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
|
||||
0.0
|
||||
>>> B = A + ones(3)
|
||||
>>> mnorm(A*B - B*A)
|
||||
1.8
|
||||
>>> mnorm(expm(A+B) - expm(A)*expm(B))
|
||||
42.0927851137247
|
||||
|
||||
"""
|
||||
A = ctx.matrix(A)
|
||||
if method == 'pade':
|
||||
return ctx._exp_pade(A)
|
||||
prec = ctx.prec
|
||||
j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
|
||||
j += int(0.5*prec**0.5)
|
||||
try:
|
||||
ctx.prec += 10 + 2*j
|
||||
tol = +ctx.eps
|
||||
A = A/2**j
|
||||
T = A
|
||||
Y = A**0 + A
|
||||
k = 2
|
||||
while 1:
|
||||
T *= A * (1/ctx.mpf(k))
|
||||
if ctx.mnorm(T, 'inf') < tol:
|
||||
break
|
||||
Y += T
|
||||
k += 1
|
||||
for k in xrange(j):
|
||||
Y = Y*Y
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
Y *= 1
|
||||
return Y
|
||||
|
||||
def cosm(ctx, A):
|
||||
r"""
|
||||
Gives the cosine of a square matrix `A`, defined in analogy
|
||||
with the matrix exponential.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> X = eye(3)
|
||||
>>> cosm(X)
|
||||
[0.54030230586814 0.0 0.0]
|
||||
[ 0.0 0.54030230586814 0.0]
|
||||
[ 0.0 0.0 0.54030230586814]
|
||||
>>> X = hilbert(3)
|
||||
>>> cosm(X)
|
||||
[ 0.424403834569555 -0.316643413047167 -0.221474945949293]
|
||||
[-0.316643413047167 0.820646708837824 -0.127183694770039]
|
||||
[-0.221474945949293 -0.127183694770039 0.909236687217541]
|
||||
>>> X = matrix([[1+j,-2],[0,-j]])
|
||||
>>> cosm(X)
|
||||
[(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
|
||||
[ 0.0 (1.54308063481524 + 0.0j)]
|
||||
"""
|
||||
B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
|
||||
if not sum(A.apply(ctx.im).apply(abs)):
|
||||
B = B.apply(ctx.re)
|
||||
return B
|
||||
|
||||
def sinm(ctx, A):
|
||||
r"""
|
||||
Gives the sine of a square matrix `A`, defined in analogy
|
||||
with the matrix exponential.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> X = eye(3)
|
||||
>>> sinm(X)
|
||||
[0.841470984807897 0.0 0.0]
|
||||
[ 0.0 0.841470984807897 0.0]
|
||||
[ 0.0 0.0 0.841470984807897]
|
||||
>>> X = hilbert(3)
|
||||
>>> sinm(X)
|
||||
[0.711608512150994 0.339783913247439 0.220742837314741]
|
||||
[0.339783913247439 0.244113865695532 0.187231271174372]
|
||||
[0.220742837314741 0.187231271174372 0.155816730769635]
|
||||
>>> X = matrix([[1+j,-2],[0,-j]])
|
||||
>>> sinm(X)
|
||||
[(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
|
||||
[ 0.0 (0.0 - 1.1752011936438j)]
|
||||
"""
|
||||
B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
|
||||
if not sum(A.apply(ctx.im).apply(abs)):
|
||||
B = B.apply(ctx.re)
|
||||
return B
|
||||
|
||||
def _sqrtm_rot(ctx, A, _may_rotate):
|
||||
# If the iteration fails to converge, cheat by performing
|
||||
# a rotation by a complex number
|
||||
u = ctx.j**0.3
|
||||
return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
|
||||
|
||||
def sqrtm(ctx, A, _may_rotate=2):
|
||||
r"""
|
||||
Computes a square root of the square matrix `A`, i.e. returns
|
||||
a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
|
||||
of a matrix, if it exists, is not unique.
|
||||
|
||||
**Examples**
|
||||
|
||||
Square roots of some simple matrices::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> sqrtm([[1,0], [0,1]])
|
||||
[1.0 0.0]
|
||||
[0.0 1.0]
|
||||
>>> sqrtm([[0,0], [0,0]])
|
||||
[0.0 0.0]
|
||||
[0.0 0.0]
|
||||
>>> sqrtm([[2,0],[0,1]])
|
||||
[1.4142135623731 0.0]
|
||||
[ 0.0 1.0]
|
||||
>>> sqrtm([[1,1],[1,0]])
|
||||
[ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
|
||||
[(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
|
||||
>>> sqrtm([[1,0],[0,1]])
|
||||
[1.0 0.0]
|
||||
[0.0 1.0]
|
||||
>>> sqrtm([[-1,0],[0,1]])
|
||||
[(0.0 - 1.0j) 0.0]
|
||||
[ 0.0 (1.0 + 0.0j)]
|
||||
>>> sqrtm([[j,0],[0,j]])
|
||||
[(0.707106781186547 + 0.707106781186547j) 0.0]
|
||||
[ 0.0 (0.707106781186547 + 0.707106781186547j)]
|
||||
|
||||
A square root of a rotation matrix, giving the corresponding
|
||||
half-angle rotation matrix::
|
||||
|
||||
>>> t1 = 0.75
|
||||
>>> t2 = t1 * 0.5
|
||||
>>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
|
||||
>>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
|
||||
>>> sqrtm(A1)
|
||||
[0.930507621912314 -0.366272529086048]
|
||||
[0.366272529086048 0.930507621912314]
|
||||
>>> A2
|
||||
[0.930507621912314 -0.366272529086048]
|
||||
[0.366272529086048 0.930507621912314]
|
||||
|
||||
The identity `(A^2)^{1/2} = A` does not necessarily hold::
|
||||
|
||||
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
||||
>>> sqrtm(A**2)
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
>>> sqrtm(A)**2
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
>>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
|
||||
>>> sqrtm(A**2)
|
||||
[ 7.43715112194995 -0.324127569985474 1.8481718827526]
|
||||
[-0.251549715716942 9.32699765900402 2.48221180985147]
|
||||
[ 4.11609388833616 0.775751877098258 13.017955697342]
|
||||
>>> chop(sqrtm(A)**2)
|
||||
[-4.0 1.0 4.0]
|
||||
[ 7.0 -8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
|
||||
For some matrices, a square root does not exist::
|
||||
|
||||
>>> sqrtm([[0,1], [0,0]])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ZeroDivisionError: matrix is numerically singular
|
||||
|
||||
Two examples from the documentation for Matlab's ``sqrtm``::
|
||||
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> sqrtm([[7,10],[15,22]])
|
||||
[1.56669890360128 1.74077655955698]
|
||||
[2.61116483933547 4.17786374293675]
|
||||
>>>
|
||||
>>> X = matrix(\
|
||||
... [[5,-4,1,0,0],
|
||||
... [-4,6,-4,1,0],
|
||||
... [1,-4,6,-4,1],
|
||||
... [0,1,-4,6,-4],
|
||||
... [0,0,1,-4,5]])
|
||||
>>> Y = matrix(\
|
||||
... [[2,-1,-0,-0,-0],
|
||||
... [-1,2,-1,0,-0],
|
||||
... [0,-1,2,-1,0],
|
||||
... [-0,0,-1,2,-1],
|
||||
... [-0,-0,-0,-1,2]])
|
||||
>>> mnorm(sqrtm(X) - Y)
|
||||
4.53155328326114e-19
|
||||
|
||||
"""
|
||||
A = ctx.matrix(A)
|
||||
# Trivial
|
||||
if A*0 == A:
|
||||
return A
|
||||
prec = ctx.prec
|
||||
if _may_rotate:
|
||||
d = ctx.det(A)
|
||||
if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
|
||||
return ctx._sqrtm_rot(A, _may_rotate-1)
|
||||
try:
|
||||
ctx.prec += 10
|
||||
tol = ctx.eps * 128
|
||||
Y = A
|
||||
Z = I = A**0
|
||||
k = 0
|
||||
# Denman-Beavers iteration
|
||||
while 1:
|
||||
Yprev = Y
|
||||
try:
|
||||
Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
|
||||
except ZeroDivisionError:
|
||||
if _may_rotate:
|
||||
Y = ctx._sqrtm_rot(A, _may_rotate-1)
|
||||
break
|
||||
else:
|
||||
raise
|
||||
mag1 = ctx.mnorm(Y-Yprev, 'inf')
|
||||
mag2 = ctx.mnorm(Y, 'inf')
|
||||
if mag1 <= mag2*tol:
|
||||
break
|
||||
if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
|
||||
return ctx._sqrtm_rot(A, _may_rotate-1)
|
||||
k += 1
|
||||
if k > ctx.prec:
|
||||
raise ctx.NoConvergence
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
Y *= 1
|
||||
return Y
|
||||
|
||||
def logm(ctx, A):
|
||||
r"""
|
||||
Computes a logarithm of the square matrix `A`, i.e. returns
|
||||
a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
|
||||
of a matrix, if it exists, is not unique.
|
||||
|
||||
**Examples**
|
||||
|
||||
Logarithms of some simple matrices::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> X = eye(3)
|
||||
>>> logm(X)
|
||||
[0.0 0.0 0.0]
|
||||
[0.0 0.0 0.0]
|
||||
[0.0 0.0 0.0]
|
||||
>>> logm(2*X)
|
||||
[0.693147180559945 0.0 0.0]
|
||||
[ 0.0 0.693147180559945 0.0]
|
||||
[ 0.0 0.0 0.693147180559945]
|
||||
>>> logm(expm(X))
|
||||
[1.0 0.0 0.0]
|
||||
[0.0 1.0 0.0]
|
||||
[0.0 0.0 1.0]
|
||||
|
||||
A logarithm of a complex matrix::
|
||||
|
||||
>>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
|
||||
>>> B = logm(X)
|
||||
>>> nprint(B)
|
||||
[ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
|
||||
[ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
|
||||
[(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
|
||||
>>> chop(expm(B))
|
||||
[(2.0 + 1.0j) 1.0 3.0]
|
||||
[(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
|
||||
[ -4.0 -5.0 (0.0 + 1.0j)]
|
||||
|
||||
A matrix `X` close to the identity matrix, for which
|
||||
`\log(\exp(X)) = \exp(\log(X)) = X` holds::
|
||||
|
||||
>>> X = eye(3) + hilbert(3)/4
|
||||
>>> X
|
||||
[ 1.25 0.125 0.0833333333333333]
|
||||
[ 0.125 1.08333333333333 0.0625]
|
||||
[0.0833333333333333 0.0625 1.05]
|
||||
>>> logm(expm(X))
|
||||
[ 1.25 0.125 0.0833333333333333]
|
||||
[ 0.125 1.08333333333333 0.0625]
|
||||
[0.0833333333333333 0.0625 1.05]
|
||||
>>> expm(logm(X))
|
||||
[ 1.25 0.125 0.0833333333333333]
|
||||
[ 0.125 1.08333333333333 0.0625]
|
||||
[0.0833333333333333 0.0625 1.05]
|
||||
|
||||
A logarithm of a rotation matrix, giving back the angle of
|
||||
the rotation::
|
||||
|
||||
>>> t = 3.7
|
||||
>>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
|
||||
>>> chop(logm(A))
|
||||
[ 0.0 -2.58318530717959]
|
||||
[2.58318530717959 0.0]
|
||||
>>> (2*pi-t)
|
||||
2.58318530717959
|
||||
|
||||
For some matrices, a logarithm does not exist::
|
||||
|
||||
>>> logm([[1,0], [0,0]])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ZeroDivisionError: matrix is numerically singular
|
||||
|
||||
Logarithm of a matrix with large entries::
|
||||
|
||||
>>> logm(hilbert(3) * 10**20).apply(re)
|
||||
[ 45.5597513593433 1.27721006042799 0.317662687717978]
|
||||
[ 1.27721006042799 42.5222778973542 2.24003708791604]
|
||||
[0.317662687717978 2.24003708791604 42.395212822267]
|
||||
|
||||
"""
|
||||
A = ctx.matrix(A)
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
tol = ctx.eps * 128
|
||||
I = A**0
|
||||
B = A
|
||||
n = 0
|
||||
while 1:
|
||||
B = ctx.sqrtm(B)
|
||||
n += 1
|
||||
if ctx.mnorm(B-I, 'inf') < 0.125:
|
||||
break
|
||||
T = X = B-I
|
||||
L = X*0
|
||||
k = 1
|
||||
while 1:
|
||||
if k & 1:
|
||||
L += T / k
|
||||
else:
|
||||
L -= T / k
|
||||
T *= X
|
||||
if ctx.mnorm(T, 'inf') < tol:
|
||||
break
|
||||
k += 1
|
||||
if k > ctx.prec:
|
||||
raise ctx.NoConvergence
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
L *= 2**n
|
||||
return L
|
||||
|
||||
def powm(ctx, A, r):
|
||||
r"""
|
||||
Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
|
||||
number `r`.
|
||||
|
||||
**Examples**
|
||||
|
||||
Powers and inverse powers of a matrix::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = True
|
||||
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
||||
>>> powm(A, 2)
|
||||
[ 63.0 20.0 69.0]
|
||||
[174.0 89.0 199.0]
|
||||
[164.0 48.0 179.0]
|
||||
>>> chop(powm(powm(A, 4), 1/4.))
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
>>> powm(extraprec(20)(powm)(A, -4), -1/4.)
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
>>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
>>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
|
||||
[ 4.0 1.0 4.0]
|
||||
[ 7.0 8.0 9.0]
|
||||
[10.0 2.0 11.0]
|
||||
|
||||
A Fibonacci-generating matrix::
|
||||
|
||||
>>> powm([[1,1],[1,0]], 10)
|
||||
[89.0 55.0]
|
||||
[55.0 34.0]
|
||||
>>> fib(10)
|
||||
55.0
|
||||
>>> powm([[1,1],[1,0]], 6.5)
|
||||
[(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
|
||||
[(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
|
||||
>>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
|
||||
(10.2078589271083 - 0.0195927472575932j)
|
||||
>>> powm([[1,1],[1,0]], 6.2)
|
||||
[ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
|
||||
[(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
|
||||
>>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
|
||||
(8.81733464837593 - 0.0133048601383712j)
|
||||
|
||||
"""
|
||||
A = ctx.matrix(A)
|
||||
r = ctx.convert(r)
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
if ctx.isint(r):
|
||||
v = A ** int(r)
|
||||
elif ctx.isint(r*2):
|
||||
y = int(r*2)
|
||||
v = ctx.sqrtm(A) ** y
|
||||
else:
|
||||
v = ctx.expm(r*ctx.logm(A))
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
v *= 1
|
||||
return v
|
||||
|
|
@ -1,516 +0,0 @@
|
|||
"""
|
||||
Linear algebra
|
||||
--------------
|
||||
|
||||
Linear equations
|
||||
................
|
||||
|
||||
Basic linear algebra is implemented; you can for example solve the linear
|
||||
equation system::
|
||||
|
||||
x + 2*y = -10
|
||||
3*x + 4*y = 10
|
||||
|
||||
using ``lu_solve``::
|
||||
|
||||
>>> A = matrix([[1, 2], [3, 4]])
|
||||
>>> b = matrix([-10, 10])
|
||||
>>> x = lu_solve(A, b)
|
||||
>>> x
|
||||
matrix(
|
||||
[['30.0'],
|
||||
['-20.0']])
|
||||
|
||||
If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
|
||||
|
||||
>>> residual(A, x, b)
|
||||
matrix(
|
||||
[['3.46944695195361e-18'],
|
||||
['3.46944695195361e-18']])
|
||||
>>> str(eps)
|
||||
'2.22044604925031e-16'
|
||||
|
||||
As you can see, the solution is quite accurate. The error is caused by the
|
||||
inaccuracy of the internal floating point arithmetic. Though, it's even smaller
|
||||
than the current machine epsilon, which basically means you can trust the
|
||||
result.
|
||||
|
||||
If you need more speed, use NumPy. Or choose a faster data type using the
|
||||
keyword ``force_type``::
|
||||
|
||||
>>> lu_solve(A, b, force_type=float)
|
||||
matrix(
|
||||
[[29.999999999999996],
|
||||
[-19.999999999999996]])
|
||||
|
||||
``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
|
||||
such systems, so the residual is minimized instead. Internally this is done
|
||||
using Cholesky decomposition to compute a least squares approximation. This means
|
||||
that that ``lu_solve`` will square the errors. If you can't afford this, use
|
||||
``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
|
||||
the residual automatically.
|
||||
|
||||
|
||||
Matrix factorization
|
||||
....................
|
||||
|
||||
The function ``lu`` computes an explicit LU factorization of a matrix::
|
||||
|
||||
>>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
|
||||
>>> print P
|
||||
[0.0 0.0 1.0]
|
||||
[1.0 0.0 0.0]
|
||||
[0.0 1.0 0.0]
|
||||
>>> print L
|
||||
[ 1.0 0.0 0.0]
|
||||
[ 0.0 1.0 0.0]
|
||||
[0.571428571428571 0.214285714285714 1.0]
|
||||
>>> print U
|
||||
[7.0 8.0 9.0]
|
||||
[0.0 2.0 3.0]
|
||||
[0.0 0.0 0.214285714285714]
|
||||
>>> print P.T*L*U
|
||||
[0.0 2.0 3.0]
|
||||
[4.0 5.0 6.0]
|
||||
[7.0 8.0 9.0]
|
||||
|
||||
Interval matrices
|
||||
-----------------
|
||||
|
||||
Matrices may contain interval elements. This allows one to perform
|
||||
basic linear algebra operations such as matrix multiplication
|
||||
and equation solving with rigorous error bounds::
|
||||
|
||||
>>> a = matrix([['0.1','0.3','1.0'],
|
||||
... ['7.1','5.5','4.8'],
|
||||
... ['3.2','4.4','5.6']], force_type=mpi)
|
||||
>>>
|
||||
>>> b = matrix(['4','0.6','0.5'], force_type=mpi)
|
||||
>>> c = lu_solve(a, b)
|
||||
>>> c
|
||||
matrix(
|
||||
[[[5.2582327113062393041, 5.2582327113062749951]],
|
||||
[[-13.155049396267856583, -13.155049396267821167]],
|
||||
[[7.4206915477497212555, 7.4206915477497310922]]])
|
||||
>>> print a*c
|
||||
[ [3.9999999999999866773, 4.0000000000000133227]]
|
||||
[[0.59999999999972430942, 0.60000000000027142733]]
|
||||
[[0.49999999999982236432, 0.50000000000018474111]]
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# *implement high-level qr()
|
||||
# *test unitvector
|
||||
# *iterative solving
|
||||
|
||||
from copy import copy
|
||||
|
||||
class LinearAlgebraMethods(object):
|
||||
|
||||
def LU_decomp(ctx, A, overwrite=False, use_cache=True):
|
||||
"""
|
||||
LU-factorization of a n*n matrix using the Gauss algorithm.
|
||||
Returns L and U in one matrix and the pivot indices.
|
||||
|
||||
Use overwrite to specify whether A will be overwritten with L and U.
|
||||
"""
|
||||
if not A.rows == A.cols:
|
||||
raise ValueError('need n*n matrix')
|
||||
# get from cache if possible
|
||||
if use_cache and isinstance(A, ctx.matrix) and A._LU:
|
||||
return A._LU
|
||||
if not overwrite:
|
||||
orig = A
|
||||
A = A.copy()
|
||||
tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
|
||||
n = A.rows
|
||||
p = [None]*(n - 1)
|
||||
for j in xrange(n - 1):
|
||||
# pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
|
||||
biggest = 0
|
||||
for k in xrange(j, n):
|
||||
s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
|
||||
if ctx.absmin(s) <= tol:
|
||||
raise ZeroDivisionError('matrix is numerically singular')
|
||||
current = 1/s * ctx.absmin(A[k,j])
|
||||
if current > biggest: # TODO: what if equal?
|
||||
biggest = current
|
||||
p[j] = k
|
||||
# swap rows according to p
|
||||
ctx.swap_row(A, j, p[j])
|
||||
if ctx.absmin(A[j,j]) <= tol:
|
||||
raise ZeroDivisionError('matrix is numerically singular')
|
||||
# calculate elimination factors and add rows
|
||||
for i in xrange(j + 1, n):
|
||||
A[i,j] /= A[j,j]
|
||||
for k in xrange(j + 1, n):
|
||||
A[i,k] -= A[i,j]*A[j,k]
|
||||
if ctx.absmin(A[n - 1,n - 1]) <= tol:
|
||||
raise ZeroDivisionError('matrix is numerically singular')
|
||||
# cache decomposition
|
||||
if not overwrite and isinstance(orig, ctx.matrix):
|
||||
orig._LU = (A, p)
|
||||
return A, p
|
||||
|
||||
def L_solve(ctx, L, b, p=None):
|
||||
"""
|
||||
Solve the lower part of a LU factorized matrix for y.
|
||||
"""
|
||||
assert L.rows == L.cols, 'need n*n matrix'
|
||||
n = L.rows
|
||||
assert len(b) == n
|
||||
b = copy(b)
|
||||
if p: # swap b according to p
|
||||
for k in xrange(0, len(p)):
|
||||
ctx.swap_row(b, k, p[k])
|
||||
# solve
|
||||
for i in xrange(1, n):
|
||||
for j in xrange(i):
|
||||
b[i] -= L[i,j] * b[j]
|
||||
return b
|
||||
|
||||
def U_solve(ctx, U, y):
|
||||
"""
|
||||
Solve the upper part of a LU factorized matrix for x.
|
||||
"""
|
||||
assert U.rows == U.cols, 'need n*n matrix'
|
||||
n = U.rows
|
||||
assert len(y) == n
|
||||
x = copy(y)
|
||||
for i in xrange(n - 1, -1, -1):
|
||||
for j in xrange(i + 1, n):
|
||||
x[i] -= U[i,j] * x[j]
|
||||
x[i] /= U[i,i]
|
||||
return x
|
||||
|
||||
def lu_solve(ctx, A, b, **kwargs):
|
||||
"""
|
||||
Ax = b => x
|
||||
|
||||
Solve a determined or overdetermined linear equations system.
|
||||
Fast LU decomposition is used, which is less accurate than QR decomposition
|
||||
(especially for overdetermined systems), but it's twice as efficient.
|
||||
Use qr_solve if you want more precision or have to solve a very ill-
|
||||
conditioned system.
|
||||
|
||||
If you specify real=True, it does not check for overdeterminded complex
|
||||
systems.
|
||||
"""
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
# do not overwrite A nor b
|
||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
||||
if A.rows < A.cols:
|
||||
raise ValueError('cannot solve underdetermined system')
|
||||
if A.rows > A.cols:
|
||||
# use least-squares method if overdetermined
|
||||
# (this increases errors)
|
||||
AH = A.H
|
||||
A = AH * A
|
||||
b = AH * b
|
||||
if (kwargs.get('real', False) or
|
||||
not sum(type(i) is ctx.mpc for i in A)):
|
||||
# TODO: necessary to check also b?
|
||||
x = ctx.cholesky_solve(A, b)
|
||||
else:
|
||||
x = ctx.lu_solve(A, b)
|
||||
else:
|
||||
# LU factorization
|
||||
A, p = ctx.LU_decomp(A)
|
||||
b = ctx.L_solve(A, b, p)
|
||||
x = ctx.U_solve(A, b)
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return x
|
||||
|
||||
def improve_solution(ctx, A, x, b, maxsteps=1):
|
||||
"""
|
||||
Improve a solution to a linear equation system iteratively.
|
||||
|
||||
This re-uses the LU decomposition and is thus cheap.
|
||||
Usually 3 up to 4 iterations are giving the maximal improvement.
|
||||
"""
|
||||
assert A.rows == A.cols, 'need n*n matrix' # TODO: really?
|
||||
for _ in xrange(maxsteps):
|
||||
r = ctx.residual(A, x, b)
|
||||
if ctx.norm(r, 2) < 10*ctx.eps:
|
||||
break
|
||||
# this uses cached LU decomposition and is thus cheap
|
||||
dx = ctx.lu_solve(A, -r)
|
||||
x += dx
|
||||
return x
|
||||
|
||||
def lu(ctx, A):
|
||||
"""
|
||||
A -> P, L, U
|
||||
|
||||
LU factorisation of a square matrix A. L is the lower, U the upper part.
|
||||
P is the permutation matrix indicating the row swaps.
|
||||
|
||||
P*A = L*U
|
||||
|
||||
If you need efficiency, use the low-level method LU_decomp instead, it's
|
||||
much more memory efficient.
|
||||
"""
|
||||
# get factorization
|
||||
A, p = ctx.LU_decomp(A)
|
||||
n = A.rows
|
||||
L = ctx.matrix(n)
|
||||
U = ctx.matrix(n)
|
||||
for i in xrange(n):
|
||||
for j in xrange(n):
|
||||
if i > j:
|
||||
L[i,j] = A[i,j]
|
||||
elif i == j:
|
||||
L[i,j] = 1
|
||||
U[i,j] = A[i,j]
|
||||
else:
|
||||
U[i,j] = A[i,j]
|
||||
# calculate permutation matrix
|
||||
P = ctx.eye(n)
|
||||
for k in xrange(len(p)):
|
||||
ctx.swap_row(P, k, p[k])
|
||||
return P, L, U
|
||||
|
||||
def unitvector(ctx, n, i):
|
||||
"""
|
||||
Return the i-th n-dimensional unit vector.
|
||||
"""
|
||||
assert 0 < i <= n, 'this unit vector does not exist'
|
||||
return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
|
||||
|
||||
def inverse(ctx, A, **kwargs):
|
||||
"""
|
||||
Calculate the inverse of a matrix.
|
||||
|
||||
If you want to solve an equation system Ax = b, it's recommended to use
|
||||
solve(A, b) instead, it's about 3 times more efficient.
|
||||
"""
|
||||
prec = ctx.prec
|
||||
try:
|
||||
ctx.prec += 10
|
||||
# do not overwrite A
|
||||
A = ctx.matrix(A, **kwargs).copy()
|
||||
n = A.rows
|
||||
# get LU factorisation
|
||||
A, p = ctx.LU_decomp(A)
|
||||
cols = []
|
||||
# calculate unit vectors and solve corresponding system to get columns
|
||||
for i in xrange(1, n + 1):
|
||||
e = ctx.unitvector(n, i)
|
||||
y = ctx.L_solve(A, e, p)
|
||||
cols.append(ctx.U_solve(A, y))
|
||||
# convert columns to matrix
|
||||
inv = []
|
||||
for i in xrange(n):
|
||||
row = []
|
||||
for j in xrange(n):
|
||||
row.append(cols[j][i])
|
||||
inv.append(row)
|
||||
result = ctx.matrix(inv, **kwargs)
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
return result
|
||||
|
||||
def householder(ctx, A):
|
||||
"""
|
||||
(A|b) -> H, p, x, res
|
||||
|
||||
(A|b) is the coefficient matrix with left hand side of an optionally
|
||||
overdetermined linear equation system.
|
||||
H and p contain all information about the transformation matrices.
|
||||
x is the solution, res the residual.
|
||||
"""
|
||||
assert isinstance(A, ctx.matrix)
|
||||
m = A.rows
|
||||
n = A.cols
|
||||
assert m >= n - 1
|
||||
# calculate Householder matrix
|
||||
p = []
|
||||
for j in xrange(0, n - 1):
|
||||
s = ctx.fsum((A[i,j])**2 for i in xrange(j, m))
|
||||
if not abs(s) > ctx.eps:
|
||||
raise ValueError('matrix is numerically singular')
|
||||
p.append(-ctx.sign(A[j,j]) * ctx.sqrt(s))
|
||||
kappa = ctx.one / (s - p[j] * A[j,j])
|
||||
A[j,j] -= p[j]
|
||||
for k in xrange(j+1, n):
|
||||
y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j, m)) * kappa
|
||||
for i in xrange(j, m):
|
||||
A[i,k] -= A[i,j] * y
|
||||
# solve Rx = c1
|
||||
x = [A[i,n - 1] for i in xrange(n - 1)]
|
||||
for i in xrange(n - 2, -1, -1):
|
||||
x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
|
||||
x[i] /= p[i]
|
||||
# calculate residual
|
||||
if not m == n - 1:
|
||||
r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
|
||||
else:
|
||||
# determined system, residual should be 0
|
||||
r = [0]*m # maybe a bad idea, changing r[i] will change all elements
|
||||
return A, p, x, r
|
||||
|
||||
#def qr(ctx, A):
|
||||
# """
|
||||
# A -> Q, R
|
||||
#
|
||||
# QR factorisation of a square matrix A using Householder decomposition.
|
||||
# Q is orthogonal, this leads to very few numerical errors.
|
||||
#
|
||||
# A = Q*R
|
||||
# """
|
||||
# H, p, x, res = householder(A)
|
||||
# TODO: implement this
|
||||
|
||||
def residual(ctx, A, x, b, **kwargs):
|
||||
"""
|
||||
Calculate the residual of a solution to a linear equation system.
|
||||
|
||||
r = A*x - b for A*x = b
|
||||
"""
|
||||
oldprec = ctx.prec
|
||||
try:
|
||||
ctx.prec *= 2
|
||||
A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
|
||||
return A*x - b
|
||||
finally:
|
||||
ctx.prec = oldprec
|
||||
|
||||
def qr_solve(ctx, A, b, norm=None, **kwargs):
|
||||
"""
|
||||
Ax = b => x, ||Ax - b||
|
||||
|
||||
Solve a determined or overdetermined linear equations system and
|
||||
calculate the norm of the residual (error).
|
||||
QR decomposition using Householder factorization is applied, which gives very
|
||||
accurate results even for ill-conditioned matrices. qr_solve is twice as
|
||||
efficient.
|
||||
"""
|
||||
if norm is None:
|
||||
norm = ctx.norm
|
||||
prec = ctx.prec
|
||||
try:
|
||||
prec += 10
|
||||
# do not overwrite A nor b
|
||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
||||
if A.rows < A.cols:
|
||||
raise ValueError('cannot solve underdetermined system')
|
||||
H, p, x, r = ctx.householder(ctx.extend(A, b))
|
||||
res = ctx.norm(r)
|
||||
# calculate residual "manually" for determined systems
|
||||
if res == 0:
|
||||
res = ctx.norm(ctx.residual(A, x, b))
|
||||
return ctx.matrix(x, **kwargs), res
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
|
||||
# TODO: possible for complex matrices? -> have a look at GSL
|
||||
def cholesky(ctx, A):
|
||||
"""
|
||||
Cholesky decomposition of a symmetric positive-definite matrix.
|
||||
|
||||
Can be used to solve linear equation systems twice as efficient compared
|
||||
to LU decomposition or to test whether A is positive-definite.
|
||||
|
||||
A = L * L.T
|
||||
Only L (the lower part) is returned.
|
||||
"""
|
||||
assert isinstance(A, ctx.matrix)
|
||||
if not A.rows == A.cols:
|
||||
raise ValueError('need n*n matrix')
|
||||
n = A.rows
|
||||
L = ctx.matrix(n)
|
||||
for j in xrange(n):
|
||||
s = A[j,j] - ctx.fsum(L[j,k]**2 for k in xrange(j))
|
||||
if s < ctx.eps:
|
||||
raise ValueError('matrix not positive-definite')
|
||||
L[j,j] = ctx.sqrt(s)
|
||||
for i in xrange(j, n):
|
||||
L[i,j] = (A[i,j] - ctx.fsum(L[i,k] * L[j,k] for k in xrange(j))) \
|
||||
/ L[j,j]
|
||||
return L
|
||||
|
||||
def cholesky_solve(ctx, A, b, **kwargs):
|
||||
"""
|
||||
Ax = b => x
|
||||
|
||||
Solve a symmetric positive-definite linear equation system.
|
||||
This is twice as efficient as lu_solve.
|
||||
|
||||
Typical use cases:
|
||||
* A.T*A
|
||||
* Hessian matrix
|
||||
* differential equations
|
||||
"""
|
||||
prec = ctx.prec
|
||||
try:
|
||||
prec += 10
|
||||
# do not overwrite A nor b
|
||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
||||
if A.rows != A.cols:
|
||||
raise ValueError('can only solve determined system')
|
||||
# Cholesky factorization
|
||||
L = ctx.cholesky(A)
|
||||
# solve
|
||||
n = L.rows
|
||||
assert len(b) == n
|
||||
for i in xrange(n):
|
||||
b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
|
||||
b[i] /= L[i,i]
|
||||
x = ctx.U_solve(L.T, b)
|
||||
return x
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
|
||||
def det(ctx, A):
|
||||
"""
|
||||
Calculate the determinant of a matrix.
|
||||
"""
|
||||
prec = ctx.prec
|
||||
try:
|
||||
# do not overwrite A
|
||||
A = ctx.matrix(A).copy()
|
||||
# use LU factorization to calculate determinant
|
||||
try:
|
||||
R, p = ctx.LU_decomp(A)
|
||||
except ZeroDivisionError:
|
||||
return 0
|
||||
z = 1
|
||||
for i, e in enumerate(p):
|
||||
if i != e:
|
||||
z *= -1
|
||||
for i in xrange(A.rows):
|
||||
z *= R[i,i]
|
||||
return z
|
||||
finally:
|
||||
ctx.prec = prec
|
||||
|
||||
def cond(ctx, A, norm=None):
|
||||
"""
|
||||
Calculate the condition number of a matrix using a specified matrix norm.
|
||||
|
||||
The condition number estimates the sensitivity of a matrix to errors.
|
||||
Example: small input errors for ill-conditioned coefficient matrices
|
||||
alter the solution of the system dramatically.
|
||||
|
||||
For ill-conditioned matrices it's recommended to use qr_solve() instead
|
||||
of lu_solve(). This does not help with input errors however, it just avoids
|
||||
to add additional errors.
|
||||
|
||||
Definition: cond(A) = ||A|| * ||A**-1||
|
||||
"""
|
||||
if norm is None:
|
||||
norm = lambda x: ctx.mnorm(x,1)
|
||||
return norm(A) * norm(ctx.inverse(A))
|
||||
|
||||
def lu_solve_mat(ctx, a, b):
|
||||
"""Solve a * x = b where a and b are matrices."""
|
||||
r = ctx.matrix(a.rows, b.cols)
|
||||
for i in range(b.cols):
|
||||
c = ctx.lu_solve(a, b.column(i))
|
||||
for j in range(len(c)):
|
||||
r[j, i] = c[j]
|
||||
return r
|
||||
|
||||
|
|
@ -1,858 +0,0 @@
|
|||
# TODO: interpret list as vectors (for multiplication)
|
||||
|
||||
rowsep = '\n'
|
||||
colsep = ' '
|
||||
|
||||
class _matrix(object):
|
||||
"""
|
||||
Numerical matrix.
|
||||
|
||||
Specify the dimensions or the data as a nested list.
|
||||
Elements default to zero.
|
||||
Use a flat list to create a column vector easily.
|
||||
|
||||
By default, only mpf is used to store the data. You can specify another type
|
||||
using force_type=type. It's possible to specify None.
|
||||
Make sure force_type(force_type()) is fast.
|
||||
|
||||
Creating matrices
|
||||
-----------------
|
||||
|
||||
Matrices in mpmath are implemented using dictionaries. Only non-zero values
|
||||
are stored, so it is cheap to represent sparse matrices.
|
||||
|
||||
The most basic way to create one is to use the ``matrix`` class directly.
|
||||
You can create an empty matrix specifying the dimensions:
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15
|
||||
>>> matrix(2)
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
>>> matrix(2, 3)
|
||||
matrix(
|
||||
[['0.0', '0.0', '0.0'],
|
||||
['0.0', '0.0', '0.0']])
|
||||
|
||||
Calling ``matrix`` with one dimension will create a square matrix.
|
||||
|
||||
To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
|
||||
|
||||
>>> A = matrix(3, 2)
|
||||
>>> A
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
>>> A.rows
|
||||
3
|
||||
>>> A.cols
|
||||
2
|
||||
|
||||
You can also change the dimension of an existing matrix. This will set the
|
||||
new elements to 0. If the new dimension is smaller than before, the
|
||||
concerning elements are discarded:
|
||||
|
||||
>>> A.rows = 2
|
||||
>>> A
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
|
||||
Internally ``mpmathify`` is used every time an element is set. This
|
||||
is done using the syntax A[row,column], counting from 0:
|
||||
|
||||
>>> A = matrix(2)
|
||||
>>> A[1,1] = 1 + 1j
|
||||
>>> A
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '(1.0 + 1.0j)']])
|
||||
|
||||
You can use the keyword ``force_type`` to change the function which is
|
||||
called on every new element:
|
||||
|
||||
>>> matrix(2, 5, force_type=int)
|
||||
matrix(
|
||||
[[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]])
|
||||
|
||||
A more comfortable way to create a matrix lets you use nested lists:
|
||||
|
||||
>>> matrix([[1, 2], [3, 4]])
|
||||
matrix(
|
||||
[['1.0', '2.0'],
|
||||
['3.0', '4.0']])
|
||||
|
||||
If you want to preserve the type of the elements you can use
|
||||
``force_type=None``:
|
||||
|
||||
>>> matrix([[1, 2.5], [1j, mpf(2)]], force_type=None)
|
||||
matrix(
|
||||
[[1, 2.5],
|
||||
[1j, '2.0']])
|
||||
|
||||
Convenient advanced functions are available for creating various standard
|
||||
matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
|
||||
``hilbert``.
|
||||
|
||||
Vectors
|
||||
.......
|
||||
|
||||
Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
|
||||
For vectors there are some things which make life easier. A column vector can
|
||||
be created using a flat list, a row vectors using an almost flat nested list::
|
||||
|
||||
>>> matrix([1, 2, 3])
|
||||
matrix(
|
||||
[['1.0'],
|
||||
['2.0'],
|
||||
['3.0']])
|
||||
>>> matrix([[1, 2, 3]])
|
||||
matrix(
|
||||
[['1.0', '2.0', '3.0']])
|
||||
|
||||
Optionally vectors can be accessed like lists, using only a single index::
|
||||
|
||||
>>> x = matrix([1, 2, 3])
|
||||
>>> x[1]
|
||||
mpf('2.0')
|
||||
>>> x[1,0]
|
||||
mpf('2.0')
|
||||
|
||||
Other
|
||||
.....
|
||||
|
||||
Like you probably expected, matrices can be printed::
|
||||
|
||||
>>> print randmatrix(3) # doctest:+SKIP
|
||||
[ 0.782963853573023 0.802057689719883 0.427895717335467]
|
||||
[0.0541876859348597 0.708243266653103 0.615134039977379]
|
||||
[ 0.856151514955773 0.544759264818486 0.686210904770947]
|
||||
|
||||
Use ``nstr`` or ``nprint`` to specify the number of digits to print::
|
||||
|
||||
>>> nprint(randmatrix(5), 3) # doctest:+SKIP
|
||||
[2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
|
||||
[6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
|
||||
[4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
|
||||
[1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
|
||||
[1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
|
||||
|
||||
As matrices are mutable, you will need to copy them sometimes::
|
||||
|
||||
>>> A = matrix(2)
|
||||
>>> A
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
>>> B = A.copy()
|
||||
>>> B[0,0] = 1
|
||||
>>> B
|
||||
matrix(
|
||||
[['1.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
>>> A
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
|
||||
Finally, it is possible to convert a matrix to a nested list. This is very useful,
|
||||
as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
|
||||
support this format::
|
||||
|
||||
>>> B.tolist()
|
||||
[[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
|
||||
|
||||
|
||||
Matrix operations
|
||||
-----------------
|
||||
|
||||
You can add and subtract matrices of compatible dimensions::
|
||||
|
||||
>>> A = matrix([[1, 2], [3, 4]])
|
||||
>>> B = matrix([[-2, 4], [5, 9]])
|
||||
>>> A + B
|
||||
matrix(
|
||||
[['-1.0', '6.0'],
|
||||
['8.0', '13.0']])
|
||||
>>> A - B
|
||||
matrix(
|
||||
[['3.0', '-2.0'],
|
||||
['-2.0', '-5.0']])
|
||||
>>> A + ones(3) # doctest:+ELLIPSIS
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: incompatible dimensions for addition
|
||||
|
||||
It is possible to multiply or add matrices and scalars. In the latter case the
|
||||
operation will be done element-wise::
|
||||
|
||||
>>> A * 2
|
||||
matrix(
|
||||
[['2.0', '4.0'],
|
||||
['6.0', '8.0']])
|
||||
>>> A / 4
|
||||
matrix(
|
||||
[['0.25', '0.5'],
|
||||
['0.75', '1.0']])
|
||||
>>> A - 1
|
||||
matrix(
|
||||
[['0.0', '1.0'],
|
||||
['2.0', '3.0']])
|
||||
|
||||
Of course you can perform matrix multiplication, if the dimensions are
|
||||
compatible::
|
||||
|
||||
>>> A * B
|
||||
matrix(
|
||||
[['8.0', '22.0'],
|
||||
['14.0', '48.0']])
|
||||
>>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
|
||||
matrix(
|
||||
[['2.0']])
|
||||
|
||||
You can raise powers of square matrices::
|
||||
|
||||
>>> A**2
|
||||
matrix(
|
||||
[['7.0', '10.0'],
|
||||
['15.0', '22.0']])
|
||||
|
||||
Negative powers will calculate the inverse::
|
||||
|
||||
>>> A**-1
|
||||
matrix(
|
||||
[['-2.0', '1.0'],
|
||||
['1.5', '-0.5']])
|
||||
>>> A * A**-1
|
||||
matrix(
|
||||
[['1.0', '1.0842021724855e-19'],
|
||||
['-2.16840434497101e-19', '1.0']])
|
||||
|
||||
Matrix transposition is straightforward::
|
||||
|
||||
>>> A = ones(2, 3)
|
||||
>>> A
|
||||
matrix(
|
||||
[['1.0', '1.0', '1.0'],
|
||||
['1.0', '1.0', '1.0']])
|
||||
>>> A.T
|
||||
matrix(
|
||||
[['1.0', '1.0'],
|
||||
['1.0', '1.0'],
|
||||
['1.0', '1.0']])
|
||||
|
||||
Norms
|
||||
.....
|
||||
|
||||
Sometimes you need to know how "large" a matrix or vector is. Due to their
|
||||
multidimensional nature it's not possible to compare them, but there are
|
||||
several functions to map a matrix or a vector to a positive real number, the
|
||||
so called norms.
|
||||
|
||||
For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
|
||||
used.
|
||||
|
||||
>>> x = matrix([-10, 2, 100])
|
||||
>>> norm(x, 1)
|
||||
mpf('112.0')
|
||||
>>> norm(x, 2)
|
||||
mpf('100.5186549850325')
|
||||
>>> norm(x, inf)
|
||||
mpf('100.0')
|
||||
|
||||
Please note that the 2-norm is the most used one, though it is more expensive
|
||||
to calculate than the 1- or oo-norm.
|
||||
|
||||
It is possible to generalize some vector norms to matrix norm::
|
||||
|
||||
>>> A = matrix([[1, -1000], [100, 50]])
|
||||
>>> mnorm(A, 1)
|
||||
mpf('1050.0')
|
||||
>>> mnorm(A, inf)
|
||||
mpf('1001.0')
|
||||
>>> mnorm(A, 'F')
|
||||
mpf('1006.2310867787777')
|
||||
|
||||
The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
|
||||
is hard to calculate and not available. The Frobenius-norm lacks some
|
||||
mathematical properties you might expect from a norm.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__data = {}
|
||||
# LU decompostion cache, this is useful when solving the same system
|
||||
# multiple times, when calculating the inverse and when calculating the
|
||||
# determinant
|
||||
self._LU = None
|
||||
convert = kwargs.get('force_type', self.ctx.convert)
|
||||
if isinstance(args[0], (list, tuple)):
|
||||
if isinstance(args[0][0], (list, tuple)):
|
||||
# interpret nested list as matrix
|
||||
A = args[0]
|
||||
self.__rows = len(A)
|
||||
self.__cols = len(A[0])
|
||||
for i, row in enumerate(A):
|
||||
for j, a in enumerate(row):
|
||||
self[i, j] = convert(a)
|
||||
else:
|
||||
# interpret list as row vector
|
||||
v = args[0]
|
||||
self.__rows = len(v)
|
||||
self.__cols = 1
|
||||
for i, e in enumerate(v):
|
||||
self[i, 0] = e
|
||||
elif isinstance(args[0], int):
|
||||
# create empty matrix of given dimensions
|
||||
if len(args) == 1:
|
||||
self.__rows = self.__cols = args[0]
|
||||
else:
|
||||
assert isinstance(args[1], int), 'expected int'
|
||||
self.__rows = args[0]
|
||||
self.__cols = args[1]
|
||||
elif isinstance(args[0], _matrix):
|
||||
A = args[0].copy()
|
||||
self.__data = A._matrix__data
|
||||
self.__rows = A._matrix__rows
|
||||
self.__cols = A._matrix__cols
|
||||
convert = kwargs.get('force_type', self.ctx.convert)
|
||||
for i in xrange(A.__rows):
|
||||
for j in xrange(A.__cols):
|
||||
A[i,j] = convert(A[i,j])
|
||||
elif hasattr(args[0], 'tolist'):
|
||||
A = self.ctx.matrix(args[0].tolist())
|
||||
self.__data = A._matrix__data
|
||||
self.__rows = A._matrix__rows
|
||||
self.__cols = A._matrix__cols
|
||||
else:
|
||||
raise TypeError('could not interpret given arguments')
|
||||
|
||||
def apply(self, f):
|
||||
"""
|
||||
Return a copy of self with the function `f` applied elementwise.
|
||||
"""
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[i,j] = f(self[i,j])
|
||||
return new
|
||||
|
||||
def __nstr__(self, n=None, **kwargs):
|
||||
# Build table of string representations of the elements
|
||||
res = []
|
||||
# Track per-column max lengths for pretty alignment
|
||||
maxlen = [0] * self.cols
|
||||
for i in range(self.rows):
|
||||
res.append([])
|
||||
for j in range(self.cols):
|
||||
if n:
|
||||
string = self.ctx.nstr(self[i,j], n, **kwargs)
|
||||
else:
|
||||
string = str(self[i,j])
|
||||
res[-1].append(string)
|
||||
maxlen[j] = max(len(string), maxlen[j])
|
||||
# Patch strings together
|
||||
for i, row in enumerate(res):
|
||||
for j, elem in enumerate(row):
|
||||
# Pad each element up to maxlen so the columns line up
|
||||
row[j] = elem.rjust(maxlen[j])
|
||||
res[i] = "[" + colsep.join(row) + "]"
|
||||
return rowsep.join(res)
|
||||
|
||||
def __str__(self):
|
||||
return self.__nstr__()
|
||||
|
||||
def _toliststr(self, avoid_type=False):
|
||||
"""
|
||||
Create a list string from a matrix.
|
||||
|
||||
If avoid_type: avoid multiple 'mpf's.
|
||||
"""
|
||||
# XXX: should be something like self.ctx._types
|
||||
typ = self.ctx.mpf
|
||||
s = '['
|
||||
for i in xrange(self.__rows):
|
||||
s += '['
|
||||
for j in xrange(self.__cols):
|
||||
if not avoid_type or not isinstance(self[i,j], typ):
|
||||
a = repr(self[i,j])
|
||||
else:
|
||||
a = "'" + str(self[i,j]) + "'"
|
||||
s += a + ', '
|
||||
s = s[:-2]
|
||||
s += '],\n '
|
||||
s = s[:-3]
|
||||
s += ']'
|
||||
return s
|
||||
|
||||
def tolist(self):
|
||||
"""
|
||||
Convert the matrix to a nested list.
|
||||
"""
|
||||
return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
|
||||
|
||||
def __repr__(self):
|
||||
if self.ctx.pretty:
|
||||
return self.__str__()
|
||||
s = 'matrix(\n'
|
||||
s += self._toliststr(avoid_type=True) + ')'
|
||||
return s
|
||||
|
||||
def __getitem__(self, key):
|
||||
if type(key) is int:
|
||||
# only sufficent for vectors
|
||||
if self.__rows == 1:
|
||||
key = (0, key)
|
||||
elif self.__cols == 1:
|
||||
key = (key, 0)
|
||||
else:
|
||||
raise IndexError('insufficient indices for matrix')
|
||||
if key in self.__data:
|
||||
return self.__data[key]
|
||||
else:
|
||||
if key[0] >= self.__rows or key[1] >= self.__cols:
|
||||
raise IndexError('matrix index out of range')
|
||||
return self.ctx.zero
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if type(key) is int:
|
||||
# only sufficent for vectors
|
||||
if self.__rows == 1:
|
||||
key = (0, key)
|
||||
elif self.__cols == 1:
|
||||
key = (key, 0)
|
||||
else:
|
||||
raise IndexError('insufficient indices for matrix')
|
||||
if key[0] >= self.__rows or key[1] >= self.__cols:
|
||||
raise IndexError('matrix index out of range')
|
||||
value = self.ctx.convert(value)
|
||||
if value: # only store non-zeros
|
||||
self.__data[key] = value
|
||||
elif key in self.__data:
|
||||
del self.__data[key]
|
||||
# TODO: maybe do this better, if the performance impact is significant
|
||||
if self._LU:
|
||||
self._LU = None
|
||||
|
||||
def __iter__(self):
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
yield self[i,j]
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, self.ctx.matrix):
|
||||
# dot multiplication TODO: use Strassen's method?
|
||||
if self.__cols != other.__rows:
|
||||
raise ValueError('dimensions not compatible for multiplication')
|
||||
new = self.ctx.matrix(self.__rows, other.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(other.__cols):
|
||||
new[i, j] = self.ctx.fdot((self[i,k], other[k,j])
|
||||
for k in xrange(other.__rows))
|
||||
return new
|
||||
else:
|
||||
# try scalar multiplication
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[i, j] = other * self[i, j]
|
||||
return new
|
||||
|
||||
def __rmul__(self, other):
|
||||
# assume other is scalar and thus commutative
|
||||
assert not isinstance(other, self.ctx.matrix)
|
||||
return self.__mul__(other)
|
||||
|
||||
def __pow__(self, other):
|
||||
# avoid cyclic import problems
|
||||
#from linalg import inverse
|
||||
if not isinstance(other, int):
|
||||
raise ValueError('only integer exponents are supported')
|
||||
if not self.__rows == self.__cols:
|
||||
raise ValueError('only powers of square matrices are defined')
|
||||
n = other
|
||||
if n == 0:
|
||||
return self.ctx.eye(self.__rows)
|
||||
if n < 0:
|
||||
n = -n
|
||||
neg = True
|
||||
else:
|
||||
neg = False
|
||||
i = n
|
||||
y = 1
|
||||
z = self.copy()
|
||||
while i != 0:
|
||||
if i % 2 == 1:
|
||||
y = y * z
|
||||
z = z*z
|
||||
i = i // 2
|
||||
if neg:
|
||||
y = self.ctx.inverse(y)
|
||||
return y
|
||||
|
||||
def __div__(self, other):
|
||||
# assume other is scalar and do element-wise divison
|
||||
assert not isinstance(other, self.ctx.matrix)
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[i,j] = self[i,j] / other
|
||||
return new
|
||||
|
||||
__truediv__ = __div__
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, self.ctx.matrix):
|
||||
if not (self.__rows == other.__rows and self.__cols == other.__cols):
|
||||
raise ValueError('incompatible dimensions for addition')
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[i,j] = self[i,j] + other[i,j]
|
||||
return new
|
||||
else:
|
||||
# assume other is scalar and add element-wise
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[i,j] += self[i,j] + other
|
||||
return new
|
||||
|
||||
def __radd__(self, other):
|
||||
return self.__add__(other)
|
||||
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
|
||||
and self.__cols == other.__cols):
|
||||
raise ValueError('incompatible dimensions for substraction')
|
||||
return self.__add__(other * (-1))
|
||||
|
||||
def __neg__(self):
|
||||
return (-1) * self
|
||||
|
||||
def __rsub__(self, other):
|
||||
return -self + other
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__rows == other.__rows and self.__cols == other.__cols \
|
||||
and self.__data == other.__data
|
||||
|
||||
def __len__(self):
|
||||
if self.rows == 1:
|
||||
return self.cols
|
||||
elif self.cols == 1:
|
||||
return self.rows
|
||||
else:
|
||||
return self.rows # do it like numpy
|
||||
|
||||
def __getrows(self):
|
||||
return self.__rows
|
||||
|
||||
def __setrows(self, value):
|
||||
for key in self.__data.copy().iterkeys():
|
||||
if key[0] >= value:
|
||||
del self.__data[key]
|
||||
self.__rows = value
|
||||
|
||||
rows = property(__getrows, __setrows, doc='number of rows')
|
||||
|
||||
def __getcols(self):
|
||||
return self.__cols
|
||||
|
||||
def __setcols(self, value):
|
||||
for key in self.__data.copy().iterkeys():
|
||||
if key[1] >= value:
|
||||
del self.__data[key]
|
||||
self.__cols = value
|
||||
|
||||
cols = property(__getcols, __setcols, doc='number of columns')
|
||||
|
||||
def transpose(self):
|
||||
new = self.ctx.matrix(self.__cols, self.__rows)
|
||||
for i in xrange(self.__rows):
|
||||
for j in xrange(self.__cols):
|
||||
new[j,i] = self[i,j]
|
||||
return new
|
||||
|
||||
T = property(transpose)
|
||||
|
||||
def conjugate(self):
|
||||
return self.apply(self.ctx.conj)
|
||||
|
||||
def transpose_conj(self):
|
||||
return self.conjugate().transpose()
|
||||
|
||||
H = property(transpose_conj)
|
||||
|
||||
def copy(self):
|
||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
||||
new.__data = self.__data.copy()
|
||||
return new
|
||||
|
||||
__copy__ = copy
|
||||
|
||||
def column(self, n):
|
||||
m = self.ctx.matrix(self.rows, 1)
|
||||
for i in range(self.rows):
|
||||
m[i] = self[i,n]
|
||||
return m
|
||||
|
||||
class MatrixMethods(object):
|
||||
|
||||
def __init__(ctx):
|
||||
# XXX: subclass
|
||||
ctx.matrix = type('matrix', (_matrix,), {})
|
||||
ctx.matrix.ctx = ctx
|
||||
ctx.matrix.convert = ctx.convert
|
||||
|
||||
def eye(ctx, n, **kwargs):
|
||||
"""
|
||||
Create square identity matrix n x n.
|
||||
"""
|
||||
A = ctx.matrix(n, **kwargs)
|
||||
for i in xrange(n):
|
||||
A[i,i] = 1
|
||||
return A
|
||||
|
||||
def diag(ctx, diagonal, **kwargs):
|
||||
"""
|
||||
Create square diagonal matrix using given list.
|
||||
|
||||
Example:
|
||||
>>> from mpmath import diag, mp
|
||||
>>> mp.pretty = False
|
||||
>>> diag([1, 2, 3])
|
||||
matrix(
|
||||
[['1.0', '0.0', '0.0'],
|
||||
['0.0', '2.0', '0.0'],
|
||||
['0.0', '0.0', '3.0']])
|
||||
"""
|
||||
A = ctx.matrix(len(diagonal), **kwargs)
|
||||
for i in xrange(len(diagonal)):
|
||||
A[i,i] = diagonal[i]
|
||||
return A
|
||||
|
||||
def zeros(ctx, *args, **kwargs):
|
||||
"""
|
||||
Create matrix m x n filled with zeros.
|
||||
One given dimension will create square matrix n x n.
|
||||
|
||||
Example:
|
||||
>>> from mpmath import zeros, mp
|
||||
>>> mp.pretty = False
|
||||
>>> zeros(2)
|
||||
matrix(
|
||||
[['0.0', '0.0'],
|
||||
['0.0', '0.0']])
|
||||
"""
|
||||
if len(args) == 1:
|
||||
m = n = args[0]
|
||||
elif len(args) == 2:
|
||||
m = args[0]
|
||||
n = args[1]
|
||||
else:
|
||||
raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
|
||||
A = ctx.matrix(m, n, **kwargs)
|
||||
for i in xrange(m):
|
||||
for j in xrange(n):
|
||||
A[i,j] = 0
|
||||
return A
|
||||
|
||||
def ones(ctx, *args, **kwargs):
|
||||
"""
|
||||
Create matrix m x n filled with ones.
|
||||
One given dimension will create square matrix n x n.
|
||||
|
||||
Example:
|
||||
>>> from mpmath import ones, mp
|
||||
>>> mp.pretty = False
|
||||
>>> ones(2)
|
||||
matrix(
|
||||
[['1.0', '1.0'],
|
||||
['1.0', '1.0']])
|
||||
"""
|
||||
if len(args) == 1:
|
||||
m = n = args[0]
|
||||
elif len(args) == 2:
|
||||
m = args[0]
|
||||
n = args[1]
|
||||
else:
|
||||
raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
|
||||
A = ctx.matrix(m, n, **kwargs)
|
||||
for i in xrange(m):
|
||||
for j in xrange(n):
|
||||
A[i,j] = 1
|
||||
return A
|
||||
|
||||
def hilbert(ctx, m, n=None):
|
||||
"""
|
||||
Create (pseudo) hilbert matrix m x n.
|
||||
One given dimension will create hilbert matrix n x n.
|
||||
|
||||
The matrix is very ill-conditioned and symmetric, positive definite if
|
||||
square.
|
||||
"""
|
||||
if n is None:
|
||||
n = m
|
||||
A = ctx.matrix(m, n)
|
||||
for i in xrange(m):
|
||||
for j in xrange(n):
|
||||
A[i,j] = ctx.one / (i + j + 1)
|
||||
return A
|
||||
|
||||
def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
|
||||
"""
|
||||
Create a random m x n matrix.
|
||||
|
||||
All values are >= min and <max.
|
||||
n defaults to m.
|
||||
|
||||
Example:
|
||||
>>> from mpmath import randmatrix
|
||||
>>> randmatrix(2) # doctest:+SKIP
|
||||
matrix(
|
||||
[['0.53491598236191806', '0.57195669543302752'],
|
||||
['0.85589992269513615', '0.82444367501382143']])
|
||||
"""
|
||||
if not n:
|
||||
n = m
|
||||
A = ctx.matrix(m, n, **kwargs)
|
||||
for i in xrange(m):
|
||||
for j in xrange(n):
|
||||
A[i,j] = ctx.rand() * (max - min) + min
|
||||
return A
|
||||
|
||||
def swap_row(ctx, A, i, j):
|
||||
"""
|
||||
Swap row i with row j.
|
||||
"""
|
||||
if i == j:
|
||||
return
|
||||
if isinstance(A, ctx.matrix):
|
||||
for k in xrange(A.cols):
|
||||
A[i,k], A[j,k] = A[j,k], A[i,k]
|
||||
elif isinstance(A, list):
|
||||
A[i], A[j] = A[j], A[i]
|
||||
else:
|
||||
raise TypeError('could not interpret type')
|
||||
|
||||
def extend(ctx, A, b):
|
||||
"""
|
||||
Extend matrix A with column b and return result.
|
||||
"""
|
||||
assert isinstance(A, ctx.matrix)
|
||||
assert A.rows == len(b)
|
||||
A = A.copy()
|
||||
A.cols += 1
|
||||
for i in xrange(A.rows):
|
||||
A[i, A.cols-1] = b[i]
|
||||
return A
|
||||
|
||||
def norm(ctx, x, p=2):
|
||||
r"""
|
||||
Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
|
||||
`\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.
|
||||
|
||||
Special cases:
|
||||
|
||||
If *x* is not iterable, this just returns ``absmax(x)``.
|
||||
|
||||
``p=1`` gives the sum of absolute values.
|
||||
|
||||
``p=2`` is the standard Euclidean vector norm.
|
||||
|
||||
``p=inf`` gives the magnitude of the largest element.
|
||||
|
||||
For *x* a matrix, ``p=2`` is the Frobenius norm.
|
||||
For operator matrix norms, use :func:`mnorm` instead.
|
||||
|
||||
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
||||
to specify the infinity norm.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> x = matrix([-10, 2, 100])
|
||||
>>> norm(x, 1)
|
||||
mpf('112.0')
|
||||
>>> norm(x, 2)
|
||||
mpf('100.5186549850325')
|
||||
>>> norm(x, inf)
|
||||
mpf('100.0')
|
||||
|
||||
"""
|
||||
try:
|
||||
iter(x)
|
||||
except TypeError:
|
||||
return ctx.absmax(x)
|
||||
if type(p) is not int:
|
||||
p = ctx.convert(p)
|
||||
if p == ctx.inf:
|
||||
return max(ctx.absmax(i) for i in x)
|
||||
elif p == 1:
|
||||
return ctx.fsum(x, absolute=1)
|
||||
elif p == 2:
|
||||
return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
|
||||
elif p > 1:
|
||||
return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
|
||||
else:
|
||||
raise ValueError('p has to be >= 1')
|
||||
|
||||
def mnorm(ctx, A, p=1):
|
||||
r"""
|
||||
Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
|
||||
are supported:
|
||||
|
||||
``p=1`` gives the 1-norm (maximal column sum)
|
||||
|
||||
``p=inf`` gives the `\infty`-norm (maximal row sum).
|
||||
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
||||
|
||||
``p=2`` (not implemented) for a square matrix is the usual spectral
|
||||
matrix norm, i.e. the largest singular value.
|
||||
|
||||
``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
|
||||
Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
|
||||
approximation of the spectral norm and satisfies
|
||||
|
||||
.. math ::
|
||||
|
||||
\frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F
|
||||
|
||||
The Frobenius norm lacks some mathematical properties that might
|
||||
be expected of a norm.
|
||||
|
||||
For general elementwise `p`-norms, use :func:`norm` instead.
|
||||
|
||||
**Examples**
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 15; mp.pretty = False
|
||||
>>> A = matrix([[1, -1000], [100, 50]])
|
||||
>>> mnorm(A, 1)
|
||||
mpf('1050.0')
|
||||
>>> mnorm(A, inf)
|
||||
mpf('1001.0')
|
||||
>>> mnorm(A, 'F')
|
||||
mpf('1006.2310867787777')
|
||||
|
||||
"""
|
||||
A = ctx.matrix(A)
|
||||
if type(p) is not int:
|
||||
if type(p) is str and 'frobenius'.startswith(p.lower()):
|
||||
return ctx.norm(A, 2)
|
||||
p = ctx.convert(p)
|
||||
m, n = A.rows, A.cols
|
||||
if p == 1:
|
||||
return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
|
||||
elif p == ctx.inf:
|
||||
return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
|
||||
else:
|
||||
raise NotImplementedError("matrix p-norm for arbitrary p")
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
# TODO: use gmpy.mpq when available?
|
||||
|
||||
class mpq(tuple):
|
||||
"""
|
||||
Rational number type, only intended for internal use.
|
||||
"""
|
||||
|
||||
"""
|
||||
def _mpmath_(self, prec, rounding):
|
||||
# XXX
|
||||
return mp.make_mpf(from_rational(self[0], self[1], prec, rounding))
|
||||
#(mpf(self[0])/self[1])._mpf_
|
||||
|
||||
"""
|
||||
|
||||
def __int__(self):
|
||||
a, b = self
|
||||
return a // b
|
||||
|
||||
def __abs__(self):
|
||||
a, b = self
|
||||
return mpq((abs(a), b))
|
||||
|
||||
def __neg__(self):
|
||||
a, b = self
|
||||
return mpq((-a, b))
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self[0])
|
||||
|
||||
def __cmp__(self, other):
|
||||
if type(other) is int and self[1] == 1:
|
||||
return cmp(self[0], other)
|
||||
return NotImplemented
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, mpq):
|
||||
a, b = self
|
||||
c, d = other
|
||||
return mpq((a*d+b*c, b*d))
|
||||
if isinstance(other, (int, long)):
|
||||
a, b = self
|
||||
return mpq((a+b*other, b))
|
||||
return NotImplemented
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, mpq):
|
||||
a, b = self
|
||||
c, d = other
|
||||
return mpq((a*d-b*c, b*d))
|
||||
if isinstance(other, (int, long)):
|
||||
a, b = self
|
||||
return mpq((a-b*other, b))
|
||||
return NotImplemented
|
||||
|
||||
def __rsub__(self, other):
|
||||
if isinstance(other, mpq):
|
||||
a, b = self
|
||||
c, d = other
|
||||
return mpq((b*c-a*d, b*d))
|
||||
if isinstance(other, (int, long)):
|
||||
a, b = self
|
||||
return mpq((b*other-a, b))
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, mpq):
|
||||
a, b = self
|
||||
c, d = other
|
||||
return mpq((a*c, b*d))
|
||||
if isinstance(other, (int, long)):
|
||||
a, b = self
|
||||
return mpq((a*other, b))
|
||||
return NotImplemented
|
||||
|
||||
def __div__(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
if other:
|
||||
a, b = self
|
||||
return mpq((a, b*other))
|
||||
raise ZeroDivisionError
|
||||
return NotImplemented
|
||||
|
||||
def __pow__(self, other):
|
||||
if type(other) is int:
|
||||
a, b = self
|
||||
return mpq((a**other, b**other))
|
||||
return NotImplemented
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
|
||||
mpq_1 = mpq((1,1))
|
||||
mpq_0 = mpq((0,1))
|
||||
mpq_1_2 = mpq((1,2))
|
||||
mpq_3_2 = mpq((3,2))
|
||||
mpq_1_4 = mpq((1,4))
|
||||
mpq_1_16 = mpq((1,16))
|
||||
mpq_3_16 = mpq((3,16))
|
||||
mpq_5_2 = mpq((5,2))
|
||||
mpq_3_4 = mpq((3,4))
|
||||
mpq_7_4 = mpq((7,4))
|
||||
mpq_5_4 = mpq((5,4))
|
||||
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
python runtests.py -py
|
||||
Use py.test to run tests (more useful for debugging)
|
||||
|
||||
python runtests.py -psyco
|
||||
Enable psyco to make tests run about 50% faster
|
||||
|
||||
python runtests.py -coverage
|
||||
Generate test coverage report. Statistics are written to /tmp
|
||||
|
||||
python runtests.py -profile
|
||||
Generate profile stats (this is much slower)
|
||||
|
||||
python runtests.py -nogmpy
|
||||
Run tests without using GMPY even if it exists
|
||||
|
||||
python runtests.py -strict
|
||||
Enforce extra tests in normalize()
|
||||
|
||||
python runtests.py -local
|
||||
Insert '../..' at the beginning of sys.path to use local mpmath
|
||||
|
||||
Additional arguments are used to filter the tests to run. Only files that have
|
||||
one of the arguments in their name are executed.
|
||||
|
||||
"""
|
||||
|
||||
import sys, os, traceback
|
||||
|
||||
if "-psyco" in sys.argv:
|
||||
sys.argv.remove('-psyco')
|
||||
import psyco
|
||||
psyco.full()
|
||||
|
||||
profile = False
|
||||
if "-profile" in sys.argv:
|
||||
sys.argv.remove('-profile')
|
||||
profile = True
|
||||
|
||||
coverage = False
|
||||
if "-coverage" in sys.argv:
|
||||
sys.argv.remove('-coverage')
|
||||
coverage = True
|
||||
|
||||
if "-nogmpy" in sys.argv:
|
||||
sys.argv.remove('-nogmpy')
|
||||
os.environ['MPMATH_NOGMPY'] = 'Y'
|
||||
|
||||
if "-strict" in sys.argv:
|
||||
sys.argv.remove('-strict')
|
||||
os.environ['MPMATH_STRICT'] = 'Y'
|
||||
|
||||
if "-local" in sys.argv:
|
||||
sys.argv.remove('-local')
|
||||
importdir = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]),
|
||||
'../..'))
|
||||
else:
|
||||
importdir = ''
|
||||
|
||||
# TODO: add a flag for this
|
||||
testdir = ''
|
||||
|
||||
def testit(importdir='', testdir=''):
|
||||
"""Run all tests in testdir while importing from importdir."""
|
||||
if importdir:
|
||||
sys.path.insert(1, importdir)
|
||||
if testdir:
|
||||
sys.path.insert(1, testdir)
|
||||
import os.path
|
||||
import mpmath
|
||||
print "mpmath imported from", os.path.dirname(mpmath.__file__)
|
||||
print "mpmath backend:", mpmath.libmp.backend.BACKEND
|
||||
print "mpmath mp class:", repr(mpmath.mp)
|
||||
print "mpmath version:", mpmath.__version__
|
||||
print "Python version:", sys.version
|
||||
print
|
||||
if "-py" in sys.argv:
|
||||
sys.argv.remove('-py')
|
||||
import py
|
||||
py.test.cmdline.main()
|
||||
else:
|
||||
import glob
|
||||
from timeit import default_timer as clock
|
||||
modules = []
|
||||
args = sys.argv[1:]
|
||||
# search for tests in directory of this file if not otherwise specified
|
||||
if not testdir:
|
||||
pattern = os.path.dirname(sys.argv[0])
|
||||
else:
|
||||
pattern = testdir
|
||||
if pattern:
|
||||
pattern += '/'
|
||||
pattern += 'test*.py'
|
||||
# look for tests (respecting specified filter)
|
||||
for f in glob.glob(pattern):
|
||||
name = os.path.splitext(os.path.basename(f))[0]
|
||||
# If run as a script, only run tests given as args, if any are given
|
||||
if args and __name__ == "__main__":
|
||||
ok = False
|
||||
for arg in args:
|
||||
if arg in name:
|
||||
ok = True
|
||||
break
|
||||
if not ok:
|
||||
continue
|
||||
module = __import__(name)
|
||||
priority = module.__dict__.get('priority', 100)
|
||||
if priority == 666:
|
||||
modules = [[priority, name, module]]
|
||||
break
|
||||
modules.append([priority, name, module])
|
||||
# execute tests
|
||||
modules.sort()
|
||||
tstart = clock()
|
||||
for priority, name, module in modules:
|
||||
print name
|
||||
for f in sorted(module.__dict__.keys()):
|
||||
if f.startswith('test_'):
|
||||
if coverage and ('numpy' in f):
|
||||
continue
|
||||
print " ", f[5:].ljust(25),
|
||||
t1 = clock()
|
||||
try:
|
||||
module.__dict__[f]()
|
||||
except:
|
||||
etype, evalue, trb = sys.exc_info()
|
||||
if etype in (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
print
|
||||
print "TEST FAILED!"
|
||||
print
|
||||
traceback.print_exc()
|
||||
t2 = clock()
|
||||
print "ok", " ", ("%.7f" % (t2-t1)), "s"
|
||||
tend = clock()
|
||||
print
|
||||
print "finished tests in", ("%.2f" % (tend-tstart)), "seconds"
|
||||
# clean sys.path
|
||||
if importdir:
|
||||
sys.path.remove(importdir)
|
||||
if testdir:
|
||||
sys.path.remove(testdir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if profile:
|
||||
import cProfile
|
||||
cProfile.run("testit('%s', '%s')" % (importdir, testdir), sort=1)
|
||||
elif coverage:
|
||||
import trace
|
||||
tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix],
|
||||
trace=0, count=1)
|
||||
tracer.run('testit(importdir, testdir)')
|
||||
r = tracer.results()
|
||||
r.write_results(show_missing=True, summary=True, coverdir="/tmp")
|
||||
else:
|
||||
testit(importdir, testdir)
|
||||
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_type_compare():
|
||||
assert mpf(2) == mpc(2,0)
|
||||
assert mpf(0) == mpc(0)
|
||||
assert mpf(2) != mpc(2, 0.00001)
|
||||
assert mpf(2) == 2.0
|
||||
assert mpf(2) != 3.0
|
||||
assert mpf(2) == 2
|
||||
assert mpf(2) != '2.0'
|
||||
assert mpc(2) != '2.0'
|
||||
|
||||
def test_add():
|
||||
assert mpf(2.5) + mpf(3) == 5.5
|
||||
assert mpf(2.5) + 3 == 5.5
|
||||
assert mpf(2.5) + 3.0 == 5.5
|
||||
assert 3 + mpf(2.5) == 5.5
|
||||
assert 3.0 + mpf(2.5) == 5.5
|
||||
assert (3+0j) + mpf(2.5) == 5.5
|
||||
assert mpc(2.5) + mpf(3) == 5.5
|
||||
assert mpc(2.5) + 3 == 5.5
|
||||
assert mpc(2.5) + 3.0 == 5.5
|
||||
assert mpc(2.5) + (3+0j) == 5.5
|
||||
assert 3 + mpc(2.5) == 5.5
|
||||
assert 3.0 + mpc(2.5) == 5.5
|
||||
assert (3+0j) + mpc(2.5) == 5.5
|
||||
|
||||
def test_sub():
|
||||
assert mpf(2.5) - mpf(3) == -0.5
|
||||
assert mpf(2.5) - 3 == -0.5
|
||||
assert mpf(2.5) - 3.0 == -0.5
|
||||
assert 3 - mpf(2.5) == 0.5
|
||||
assert 3.0 - mpf(2.5) == 0.5
|
||||
assert (3+0j) - mpf(2.5) == 0.5
|
||||
assert mpc(2.5) - mpf(3) == -0.5
|
||||
assert mpc(2.5) - 3 == -0.5
|
||||
assert mpc(2.5) - 3.0 == -0.5
|
||||
assert mpc(2.5) - (3+0j) == -0.5
|
||||
assert 3 - mpc(2.5) == 0.5
|
||||
assert 3.0 - mpc(2.5) == 0.5
|
||||
assert (3+0j) - mpc(2.5) == 0.5
|
||||
|
||||
def test_mul():
|
||||
assert mpf(2.5) * mpf(3) == 7.5
|
||||
assert mpf(2.5) * 3 == 7.5
|
||||
assert mpf(2.5) * 3.0 == 7.5
|
||||
assert 3 * mpf(2.5) == 7.5
|
||||
assert 3.0 * mpf(2.5) == 7.5
|
||||
assert (3+0j) * mpf(2.5) == 7.5
|
||||
assert mpc(2.5) * mpf(3) == 7.5
|
||||
assert mpc(2.5) * 3 == 7.5
|
||||
assert mpc(2.5) * 3.0 == 7.5
|
||||
assert mpc(2.5) * (3+0j) == 7.5
|
||||
assert 3 * mpc(2.5) == 7.5
|
||||
assert 3.0 * mpc(2.5) == 7.5
|
||||
assert (3+0j) * mpc(2.5) == 7.5
|
||||
|
||||
def test_div():
|
||||
assert mpf(6) / mpf(3) == 2.0
|
||||
assert mpf(6) / 3 == 2.0
|
||||
assert mpf(6) / 3.0 == 2.0
|
||||
assert 6 / mpf(3) == 2.0
|
||||
assert 6.0 / mpf(3) == 2.0
|
||||
assert (6+0j) / mpf(3.0) == 2.0
|
||||
assert mpc(6) / mpf(3) == 2.0
|
||||
assert mpc(6) / 3 == 2.0
|
||||
assert mpc(6) / 3.0 == 2.0
|
||||
assert mpc(6) / (3+0j) == 2.0
|
||||
assert 6 / mpc(3) == 2.0
|
||||
assert 6.0 / mpc(3) == 2.0
|
||||
assert (6+0j) / mpc(3) == 2.0
|
||||
|
||||
def test_pow():
|
||||
assert mpf(6) ** mpf(3) == 216.0
|
||||
assert mpf(6) ** 3 == 216.0
|
||||
assert mpf(6) ** 3.0 == 216.0
|
||||
assert 6 ** mpf(3) == 216.0
|
||||
assert 6.0 ** mpf(3) == 216.0
|
||||
assert (6+0j) ** mpf(3.0) == 216.0
|
||||
assert mpc(6) ** mpf(3) == 216.0
|
||||
assert mpc(6) ** 3 == 216.0
|
||||
assert mpc(6) ** 3.0 == 216.0
|
||||
assert mpc(6) ** (3+0j) == 216.0
|
||||
assert 6 ** mpc(3) == 216.0
|
||||
assert 6.0 ** mpc(3) == 216.0
|
||||
assert (6+0j) ** mpc(3) == 216.0
|
||||
|
||||
def test_mixed_misc():
|
||||
assert 1 + mpf(3) == mpf(3) + 1 == 4
|
||||
assert 1 - mpf(3) == -(mpf(3) - 1) == -2
|
||||
assert 3 * mpf(2) == mpf(2) * 3 == 6
|
||||
assert 6 / mpf(2) == mpf(6) / 2 == 3
|
||||
assert 1.0 + mpf(3) == mpf(3) + 1.0 == 4
|
||||
assert 1.0 - mpf(3) == -(mpf(3) - 1.0) == -2
|
||||
assert 3.0 * mpf(2) == mpf(2) * 3.0 == 6
|
||||
assert 6.0 / mpf(2) == mpf(6) / 2.0 == 3
|
||||
|
||||
def test_add_misc():
|
||||
mp.dps = 15
|
||||
assert mpf(4) + mpf(-70) == -66
|
||||
assert mpf(1) + mpf(1.1)/80 == 1 + 1.1/80
|
||||
assert mpf((1, 10000000000)) + mpf(3) == mpf((1, 10000000000))
|
||||
assert mpf(3) + mpf((1, 10000000000)) == mpf((1, 10000000000))
|
||||
assert mpf((1, -10000000000)) + mpf(3) == mpf(3)
|
||||
assert mpf(3) + mpf((1, -10000000000)) == mpf(3)
|
||||
assert mpf(1) + 1e-15 != 1
|
||||
assert mpf(1) + 1e-20 == 1
|
||||
assert mpf(1.07e-22) + 0 == mpf(1.07e-22)
|
||||
assert mpf(0) + mpf(1.07e-22) == mpf(1.07e-22)
|
||||
|
||||
def test_complex_misc():
|
||||
# many more tests needed
|
||||
assert 1 + mpc(2) == 3
|
||||
assert not mpc(2).ae(2 + 1e-13)
|
||||
assert mpc(2+1e-15j).ae(2)
|
||||
|
||||
def test_complex_zeros():
|
||||
for a in [0,2]:
|
||||
for b in [0,3]:
|
||||
for c in [0,4]:
|
||||
for d in [0,5]:
|
||||
assert mpc(a,b)*mpc(c,d) == complex(a,b)*complex(c,d)
|
||||
|
||||
def test_hash():
|
||||
for i in range(-256, 256):
|
||||
assert hash(mpf(i)) == hash(i)
|
||||
assert hash(mpf(0.5)) == hash(0.5)
|
||||
assert hash(mpc(2,3)) == hash(2+3j)
|
||||
# Check that this doesn't fail
|
||||
assert hash(inf)
|
||||
# Check that overflow doesn't assign equal hashes to large numbers
|
||||
assert hash(mpf('1e1000')) != hash('1e10000')
|
||||
assert hash(mpc(100,'1e1000')) != hash(mpc(200,'1e1000'))
|
||||
|
||||
def test_arithmetic_functions():
|
||||
import operator
|
||||
ops = [(operator.add, fadd), (operator.sub, fsub), (operator.mul, fmul),
|
||||
(operator.div, fdiv)]
|
||||
a = mpf(0.27)
|
||||
b = mpf(1.13)
|
||||
c = mpc(0.51+2.16j)
|
||||
d = mpc(1.08-0.99j)
|
||||
for x in [a,b,c,d]:
|
||||
for y in [a,b,c,d]:
|
||||
for op, fop in ops:
|
||||
if fop is not fdiv:
|
||||
mp.prec = 200
|
||||
z0 = op(x,y)
|
||||
mp.prec = 60
|
||||
z1 = op(x,y)
|
||||
mp.prec = 53
|
||||
z2 = op(x,y)
|
||||
assert fop(x, y, prec=60) == z1
|
||||
assert fop(x, y) == z2
|
||||
if fop is not fdiv:
|
||||
assert fop(x, y, prec=inf) == z0
|
||||
assert fop(x, y, dps=inf) == z0
|
||||
assert fop(x, y, exact=True) == z0
|
||||
assert fneg(fneg(z1, exact=True), prec=inf) == z1
|
||||
assert fneg(z1) == -(+z1)
|
||||
mp.dps = 15
|
||||
|
|
@ -1,172 +0,0 @@
|
|||
"""
|
||||
Test bit-level integer and mpf operations
|
||||
"""
|
||||
|
||||
from mpmath import *
|
||||
from mpmath.libmp import *
|
||||
|
||||
def test_bitcount():
|
||||
assert bitcount(0) == 0
|
||||
assert bitcount(1) == 1
|
||||
assert bitcount(7) == 3
|
||||
assert bitcount(8) == 4
|
||||
assert bitcount(2**100) == 101
|
||||
assert bitcount(2**100-1) == 100
|
||||
|
||||
def test_trailing():
|
||||
assert trailing(0) == 0
|
||||
assert trailing(1) == 0
|
||||
assert trailing(2) == 1
|
||||
assert trailing(7) == 0
|
||||
assert trailing(8) == 3
|
||||
assert trailing(2**100) == 100
|
||||
assert trailing(2**100-1) == 0
|
||||
|
||||
def test_round_down():
|
||||
assert from_man_exp(0, -4, 4, round_down)[:3] == (0, 0, 0)
|
||||
assert from_man_exp(0xf0, -4, 4, round_down)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf1, -4, 4, round_down)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xff, -4, 4, round_down)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(-0xf0, -4, 4, round_down)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf1, -4, 4, round_down)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xff, -4, 4, round_down)[:3] == (1, 15, 0)
|
||||
|
||||
def test_round_up():
|
||||
assert from_man_exp(0, -4, 4, round_up)[:3] == (0, 0, 0)
|
||||
assert from_man_exp(0xf0, -4, 4, round_up)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf1, -4, 4, round_up)[:3] == (0, 1, 4)
|
||||
assert from_man_exp(0xff, -4, 4, round_up)[:3] == (0, 1, 4)
|
||||
assert from_man_exp(-0xf0, -4, 4, round_up)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf1, -4, 4, round_up)[:3] == (1, 1, 4)
|
||||
assert from_man_exp(-0xff, -4, 4, round_up)[:3] == (1, 1, 4)
|
||||
|
||||
def test_round_floor():
|
||||
assert from_man_exp(0, -4, 4, round_floor)[:3] == (0, 0, 0)
|
||||
assert from_man_exp(0xf0, -4, 4, round_floor)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf1, -4, 4, round_floor)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xff, -4, 4, round_floor)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(-0xf0, -4, 4, round_floor)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf1, -4, 4, round_floor)[:3] == (1, 1, 4)
|
||||
assert from_man_exp(-0xff, -4, 4, round_floor)[:3] == (1, 1, 4)
|
||||
|
||||
def test_round_ceiling():
|
||||
assert from_man_exp(0, -4, 4, round_ceiling)[:3] == (0, 0, 0)
|
||||
assert from_man_exp(0xf0, -4, 4, round_ceiling)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf1, -4, 4, round_ceiling)[:3] == (0, 1, 4)
|
||||
assert from_man_exp(0xff, -4, 4, round_ceiling)[:3] == (0, 1, 4)
|
||||
assert from_man_exp(-0xf0, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf1, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xff, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
||||
|
||||
def test_round_nearest():
|
||||
assert from_man_exp(0, -4, 4, round_nearest)[:3] == (0, 0, 0)
|
||||
assert from_man_exp(0xf0, -4, 4, round_nearest)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf7, -4, 4, round_nearest)[:3] == (0, 15, 0)
|
||||
assert from_man_exp(0xf8, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1000 -> 10000.0
|
||||
assert from_man_exp(0xf9, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1001 -> 10000.0
|
||||
assert from_man_exp(0xe8, -4, 4, round_nearest)[:3] == (0, 7, 1) # 1110.1000 -> 1110.0
|
||||
assert from_man_exp(0xe9, -4, 4, round_nearest)[:3] == (0, 15, 0) # 1110.1001 -> 1111.0
|
||||
assert from_man_exp(-0xf0, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf7, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
||||
assert from_man_exp(-0xf8, -4, 4, round_nearest)[:3] == (1, 1, 4)
|
||||
assert from_man_exp(-0xf9, -4, 4, round_nearest)[:3] == (1, 1, 4)
|
||||
assert from_man_exp(-0xe8, -4, 4, round_nearest)[:3] == (1, 7, 1)
|
||||
assert from_man_exp(-0xe9, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
||||
|
||||
def test_rounding_bugs():
|
||||
# 1 less than power-of-two cases
|
||||
assert from_man_exp(72057594037927935, -56, 53, round_up) == (0, 1, 0, 1)
|
||||
assert from_man_exp(73786976294838205979l, -65, 53, round_nearest) == (0, 1, 1, 1)
|
||||
assert from_man_exp(31, 0, 4, round_up) == (0, 1, 5, 1)
|
||||
assert from_man_exp(-31, 0, 4, round_floor) == (1, 1, 5, 1)
|
||||
assert from_man_exp(255, 0, 7, round_up) == (0, 1, 8, 1)
|
||||
assert from_man_exp(-255, 0, 7, round_floor) == (1, 1, 8, 1)
|
||||
|
||||
def test_rounding_issue160():
|
||||
a = from_man_exp(9867,-100)
|
||||
b = from_man_exp(9867,-200)
|
||||
c = from_man_exp(-1,0)
|
||||
z = (1, 1023, -10, 10)
|
||||
assert mpf_add(a, c, 10, 'd') == z
|
||||
assert mpf_add(b, c, 10, 'd') == z
|
||||
assert mpf_add(c, a, 10, 'd') == z
|
||||
assert mpf_add(c, b, 10, 'd') == z
|
||||
|
||||
def test_perturb():
|
||||
a = fone
|
||||
b = from_float(0.99999999999999989)
|
||||
c = from_float(1.0000000000000002)
|
||||
assert mpf_perturb(a, 0, 53, round_nearest) == a
|
||||
assert mpf_perturb(a, 1, 53, round_nearest) == a
|
||||
assert mpf_perturb(a, 0, 53, round_up) == c
|
||||
assert mpf_perturb(a, 0, 53, round_ceiling) == c
|
||||
assert mpf_perturb(a, 0, 53, round_down) == a
|
||||
assert mpf_perturb(a, 0, 53, round_floor) == a
|
||||
assert mpf_perturb(a, 1, 53, round_up) == a
|
||||
assert mpf_perturb(a, 1, 53, round_ceiling) == a
|
||||
assert mpf_perturb(a, 1, 53, round_down) == b
|
||||
assert mpf_perturb(a, 1, 53, round_floor) == b
|
||||
a = mpf_neg(a)
|
||||
b = mpf_neg(b)
|
||||
c = mpf_neg(c)
|
||||
assert mpf_perturb(a, 0, 53, round_nearest) == a
|
||||
assert mpf_perturb(a, 1, 53, round_nearest) == a
|
||||
assert mpf_perturb(a, 0, 53, round_up) == a
|
||||
assert mpf_perturb(a, 0, 53, round_floor) == a
|
||||
assert mpf_perturb(a, 0, 53, round_down) == b
|
||||
assert mpf_perturb(a, 0, 53, round_ceiling) == b
|
||||
assert mpf_perturb(a, 1, 53, round_up) == c
|
||||
assert mpf_perturb(a, 1, 53, round_floor) == c
|
||||
assert mpf_perturb(a, 1, 53, round_down) == a
|
||||
assert mpf_perturb(a, 1, 53, round_ceiling) == a
|
||||
|
||||
def test_add_exact():
|
||||
ff = from_float
|
||||
assert mpf_add(ff(3.0), ff(2.5)) == ff(5.5)
|
||||
assert mpf_add(ff(3.0), ff(-2.5)) == ff(0.5)
|
||||
assert mpf_add(ff(-3.0), ff(2.5)) == ff(-0.5)
|
||||
assert mpf_add(ff(-3.0), ff(-2.5)) == ff(-5.5)
|
||||
assert mpf_sub(mpf_add(fone, ff(1e-100)), fone) == ff(1e-100)
|
||||
assert mpf_sub(mpf_add(ff(1e-100), fone), fone) == ff(1e-100)
|
||||
assert mpf_sub(mpf_add(fone, ff(-1e-100)), fone) == ff(-1e-100)
|
||||
assert mpf_sub(mpf_add(ff(-1e-100), fone), fone) == ff(-1e-100)
|
||||
assert mpf_add(fone, fzero) == fone
|
||||
assert mpf_add(fzero, fone) == fone
|
||||
assert mpf_add(fzero, fzero) == fzero
|
||||
|
||||
def test_long_exponent_shifts():
|
||||
mp.dps = 15
|
||||
# Check for possible bugs due to exponent arithmetic overflow
|
||||
# in a C implementation
|
||||
x = mpf(1)
|
||||
for p in [32, 64]:
|
||||
a = ldexp(1,2**(p-1))
|
||||
b = ldexp(1,2**p)
|
||||
c = ldexp(1,2**(p+1))
|
||||
d = ldexp(1,-2**(p-1))
|
||||
e = ldexp(1,-2**p)
|
||||
f = ldexp(1,-2**(p+1))
|
||||
assert (x+a) == a
|
||||
assert (x+b) == b
|
||||
assert (x+c) == c
|
||||
assert (x+d) == x
|
||||
assert (x+e) == x
|
||||
assert (x+f) == x
|
||||
assert (a+x) == a
|
||||
assert (b+x) == b
|
||||
assert (c+x) == c
|
||||
assert (d+x) == x
|
||||
assert (e+x) == x
|
||||
assert (f+x) == x
|
||||
assert (x-a) == -a
|
||||
assert (x-b) == -b
|
||||
assert (x-c) == -c
|
||||
assert (x-d) == x
|
||||
assert (x-e) == x
|
||||
assert (x-f) == x
|
||||
assert (a-x) == a
|
||||
assert (b-x) == b
|
||||
assert (c-x) == c
|
||||
assert (d-x) == -x
|
||||
assert (e-x) == -x
|
||||
assert (f-x) == -x
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_approximation():
|
||||
mp.dps = 15
|
||||
f = lambda x: cos(2-2*x)/x
|
||||
p, err = chebyfit(f, [2, 4], 8, error=True)
|
||||
assert err < 1e-5
|
||||
for i in range(10):
|
||||
x = 2 + i/5.
|
||||
assert abs(polyval(p, x) - f(x)) < err
|
||||
|
||||
def test_limits():
|
||||
mp.dps = 15
|
||||
assert limit(lambda x: (x-sin(x))/x**3, 0).ae(mpf(1)/6)
|
||||
assert limit(lambda n: (1+1/n)**n, inf).ae(e)
|
||||
|
||||
def test_polyval():
|
||||
assert polyval([], 3) == 0
|
||||
assert polyval([0], 3) == 0
|
||||
assert polyval([5], 3) == 5
|
||||
# 4x^3 - 2x + 5
|
||||
p = [4, 0, -2, 5]
|
||||
assert polyval(p,4) == 253
|
||||
assert polyval(p,4,derivative=True) == (253, 190)
|
||||
|
||||
def test_polyroots():
|
||||
p = polyroots([1,-4])
|
||||
assert p[0].ae(4)
|
||||
p, q = polyroots([1,2,3])
|
||||
assert p.ae(-1 - sqrt(2)*j)
|
||||
assert q.ae(-1 + sqrt(2)*j)
|
||||
#this is not a real test, it only tests a specific case
|
||||
assert polyroots([1]) == []
|
||||
try:
|
||||
polyroots([0])
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_pade():
|
||||
one = mpf(1)
|
||||
mp.dps = 20
|
||||
N = 10
|
||||
a = [one]
|
||||
k = 1
|
||||
for i in range(1, N+1):
|
||||
k *= i
|
||||
a.append(one/k)
|
||||
p, q = pade(a, N//2, N//2)
|
||||
for x in arange(0, 1, 0.1):
|
||||
r = polyval(p[::-1], x)/polyval(q[::-1], x)
|
||||
assert(r.ae(exp(x), 1.0e-10))
|
||||
mp.dps = 15
|
||||
|
||||
def test_fourier():
|
||||
mp.dps = 15
|
||||
c, s = fourier(lambda x: x+1, [-1, 2], 2)
|
||||
#plot([lambda x: x+1, lambda x: fourierval((c, s), [-1, 2], x)], [-1, 2])
|
||||
assert c[0].ae(1.5)
|
||||
assert c[1].ae(-3*sqrt(3)/(2*pi))
|
||||
assert c[2].ae(3*sqrt(3)/(4*pi))
|
||||
assert s[0] == 0
|
||||
assert s[1].ae(3/(2*pi))
|
||||
assert s[2].ae(3/(4*pi))
|
||||
assert fourierval((c, s), [-1, 2], 1).ae(1.9134966715663442)
|
||||
|
||||
def test_differint():
|
||||
mp.dps = 15
|
||||
assert differint(lambda t: t, 2, -0.5).ae(8*sqrt(2/pi)/3)
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
from mpmath import *
|
||||
from random import seed, randint, random
|
||||
import math
|
||||
|
||||
# Test compatibility with Python floats, which are
|
||||
# IEEE doubles (53-bit)
|
||||
|
||||
N = 5000
|
||||
seed(1)
|
||||
|
||||
# Choosing exponents between roughly -140, 140 ensures that
|
||||
# the Python floats don't overflow or underflow
|
||||
xs = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
|
||||
ys = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
|
||||
|
||||
# include some equal values
|
||||
ys[int(N*0.8):] = xs[int(N*0.8):]
|
||||
|
||||
# Detect whether Python is compiled to use 80-bit floating-point
|
||||
# instructions, in which case the double compatibility test breaks
|
||||
uses_x87 = -4.1974624032366689e+117 / -8.4657370748010221e-47 \
|
||||
== 4.9581771393902231e+163
|
||||
|
||||
def test_double_compatibility():
|
||||
mp.prec = 53
|
||||
for x, y in zip(xs, ys):
|
||||
mpx = mpf(x)
|
||||
mpy = mpf(y)
|
||||
assert mpf(x) == x
|
||||
assert (mpx < mpy) == (x < y)
|
||||
assert (mpx > mpy) == (x > y)
|
||||
assert (mpx == mpy) == (x == y)
|
||||
assert (mpx != mpy) == (x != y)
|
||||
assert (mpx <= mpy) == (x <= y)
|
||||
assert (mpx >= mpy) == (x >= y)
|
||||
assert mpx == mpx
|
||||
if uses_x87:
|
||||
mp.prec = 64
|
||||
a = mpx + mpy
|
||||
b = mpx * mpy
|
||||
c = mpx / mpy
|
||||
d = mpx % mpy
|
||||
mp.prec = 53
|
||||
assert +a == x + y
|
||||
assert +b == x * y
|
||||
assert +c == x / y
|
||||
assert +d == x % y
|
||||
else:
|
||||
assert mpx + mpy == x + y
|
||||
assert mpx * mpy == x * y
|
||||
assert mpx / mpy == x / y
|
||||
assert mpx % mpy == x % y
|
||||
assert abs(mpx) == abs(x)
|
||||
assert mpf(repr(x)) == x
|
||||
assert ceil(mpx) == math.ceil(x)
|
||||
assert floor(mpx) == math.floor(x)
|
||||
|
||||
def test_sqrt():
|
||||
# this fails quite often. it appers to be float
|
||||
# that rounds the wrong way, not mpf
|
||||
fail = 0
|
||||
mp.prec = 53
|
||||
for x in xs:
|
||||
x = abs(x)
|
||||
mp.prec = 100
|
||||
mp_high = mpf(x)**0.5
|
||||
mp.prec = 53
|
||||
mp_low = mpf(x)**0.5
|
||||
fp = x**0.5
|
||||
assert abs(mp_low-mp_high) <= abs(fp-mp_high)
|
||||
fail += mp_low != fp
|
||||
assert fail < N/10
|
||||
|
||||
def test_bugs():
|
||||
# particular bugs
|
||||
assert mpf(4.4408920985006262E-16) < mpf(1.7763568394002505E-15)
|
||||
assert mpf(-4.4408920985006262E-16) > mpf(-1.7763568394002505E-15)
|
||||
|
|
@ -1,186 +0,0 @@
|
|||
import random
|
||||
from mpmath import *
|
||||
from mpmath.libmp import *
|
||||
|
||||
|
||||
def test_basic_string():
|
||||
"""
|
||||
Test basic string conversion
|
||||
"""
|
||||
mp.dps = 15
|
||||
assert mpf('3') == mpf('3.0') == mpf('0003.') == mpf('0.03e2') == mpf(3.0)
|
||||
assert mpf('30') == mpf('30.0') == mpf('00030.') == mpf(30.0)
|
||||
for i in range(10):
|
||||
for j in range(10):
|
||||
assert mpf('%ie%i' % (i,j)) == i * 10**j
|
||||
assert str(mpf('25000.0')) == '25000.0'
|
||||
assert str(mpf('2500.0')) == '2500.0'
|
||||
assert str(mpf('250.0')) == '250.0'
|
||||
assert str(mpf('25.0')) == '25.0'
|
||||
assert str(mpf('2.5')) == '2.5'
|
||||
assert str(mpf('0.25')) == '0.25'
|
||||
assert str(mpf('0.025')) == '0.025'
|
||||
assert str(mpf('0.0025')) == '0.0025'
|
||||
assert str(mpf('0.00025')) == '0.00025'
|
||||
assert str(mpf('0.000025')) == '2.5e-5'
|
||||
assert str(mpf(0)) == '0.0'
|
||||
assert str(mpf('2.5e1000000000000000000000')) == '2.5e+1000000000000000000000'
|
||||
assert str(mpf('2.6e-1000000000000000000000')) == '2.6e-1000000000000000000000'
|
||||
assert str(mpf(1.23402834e-15)) == '1.23402834e-15'
|
||||
assert str(mpf(-1.23402834e-15)) == '-1.23402834e-15'
|
||||
assert str(mpf(-1.2344e-15)) == '-1.2344e-15'
|
||||
assert repr(mpf(-1.2344e-15)) == "mpf('-1.2343999999999999e-15')"
|
||||
|
||||
def test_pretty():
|
||||
mp.pretty = True
|
||||
assert repr(mpf(2.5)) == '2.5'
|
||||
assert repr(mpc(2.5,3.5)) == '(2.5 + 3.5j)'
|
||||
assert repr(mpi(2.5,3.5)) == '[2.5, 3.5]'
|
||||
mp.pretty = False
|
||||
|
||||
def test_str_whitespace():
|
||||
assert mpf('1.26 ') == 1.26
|
||||
|
||||
def test_unicode():
|
||||
mp.dps = 15
|
||||
assert mpf(u'2.76') == 2.76
|
||||
assert mpf(u'inf') == inf
|
||||
|
||||
def test_str_format():
|
||||
assert to_str(from_float(0.1),15,strip_zeros=False) == '0.100000000000000'
|
||||
assert to_str(from_float(0.0),15,show_zero_exponent=True) == '0.0e+0'
|
||||
assert to_str(from_float(0.0),0,show_zero_exponent=True) == '.0e+0'
|
||||
assert to_str(from_float(0.0),0,show_zero_exponent=False) == '.0'
|
||||
assert to_str(from_float(0.0),1,show_zero_exponent=True) == '0.0e+0'
|
||||
assert to_str(from_float(0.0),1,show_zero_exponent=False) == '0.0'
|
||||
assert to_str(from_float(1.23),3,show_zero_exponent=True) == '1.23e+0'
|
||||
assert to_str(from_float(1.23456789000000e-2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e-2'
|
||||
assert to_str(from_float(1.23456789000000e+2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e+2'
|
||||
assert to_str(from_float(2.1287e14), 15, max_fixed=1000) == '212870000000000.0'
|
||||
assert to_str(from_float(2.1287e15), 15, max_fixed=1000) == '2128700000000000.0'
|
||||
assert to_str(from_float(2.1287e16), 15, max_fixed=1000) == '21287000000000000.0'
|
||||
assert to_str(from_float(2.1287e30), 15, max_fixed=1000) == '2128700000000000000000000000000.0'
|
||||
|
||||
def test_tight_string_conversion():
|
||||
mp.dps = 15
|
||||
# In an old version, '0.5' wasn't recognized as representing
|
||||
# an exact binary number and was erroneously rounded up or down
|
||||
assert from_str('0.5', 10, round_floor) == fhalf
|
||||
assert from_str('0.5', 10, round_ceiling) == fhalf
|
||||
|
||||
def test_eval_repr_invariant():
|
||||
"""Test that eval(repr(x)) == x"""
|
||||
random.seed(123)
|
||||
for dps in [10, 15, 20, 50, 100]:
|
||||
mp.dps = dps
|
||||
for i in xrange(1000):
|
||||
a = mpf(random.random())**0.5 * 10**random.randint(-100, 100)
|
||||
assert eval(repr(a)) == a
|
||||
mp.dps = 15
|
||||
|
||||
def test_str_bugs():
|
||||
mp.dps = 15
|
||||
# Decimal rounding used to give the wrong exponent in some cases
|
||||
assert str(mpf('1e600')) == '1.0e+600'
|
||||
assert str(mpf('1e10000')) == '1.0e+10000'
|
||||
|
||||
def test_str_prec0():
|
||||
assert to_str(from_float(1.234), 0) == '.0e+0'
|
||||
assert to_str(from_float(1e-15), 0) == '.0e-15'
|
||||
assert to_str(from_float(1e+15), 0) == '.0e+15'
|
||||
assert to_str(from_float(-1e-15), 0) == '-.0e-15'
|
||||
assert to_str(from_float(-1e+15), 0) == '-.0e+15'
|
||||
|
||||
def test_convert_rational():
|
||||
mp.dps = 15
|
||||
assert from_rational(30, 5, 53, round_nearest) == (0, 3, 1, 2)
|
||||
assert from_rational(-7, 4, 53, round_nearest) == (1, 7, -2, 3)
|
||||
assert to_rational((0, 1, -1, 1)) == (1, 2)
|
||||
|
||||
def test_custom_class():
|
||||
class mympf:
|
||||
@property
|
||||
def _mpf_(self):
|
||||
return mpf(3.5)._mpf_
|
||||
class mympc:
|
||||
@property
|
||||
def _mpc_(self):
|
||||
return mpf(3.5)._mpf_, mpf(2.5)._mpf_
|
||||
assert mpf(2) + mympf() == 5.5
|
||||
assert mympf() + mpf(2) == 5.5
|
||||
assert mpf(mympf()) == 3.5
|
||||
assert mympc() + mpc(2) == mpc(5.5, 2.5)
|
||||
assert mpc(2) + mympc() == mpc(5.5, 2.5)
|
||||
assert mpc(mympc()) == (3.5+2.5j)
|
||||
|
||||
def test_conversion_methods():
|
||||
class SomethingRandom:
|
||||
pass
|
||||
class SomethingReal:
|
||||
def _mpmath_(self, prec, rounding):
|
||||
return mp.make_mpf(from_str('1.3', prec, rounding))
|
||||
class SomethingComplex:
|
||||
def _mpmath_(self, prec, rounding):
|
||||
return mp.make_mpc((from_str('1.3', prec, rounding), \
|
||||
from_str('1.7', prec, rounding)))
|
||||
x = mpf(3)
|
||||
z = mpc(3)
|
||||
a = SomethingRandom()
|
||||
y = SomethingReal()
|
||||
w = SomethingComplex()
|
||||
for d in [15, 45]:
|
||||
mp.dps = d
|
||||
assert (x+y).ae(mpf('4.3'))
|
||||
assert (y+x).ae(mpf('4.3'))
|
||||
assert (x+w).ae(mpc('4.3', '1.7'))
|
||||
assert (w+x).ae(mpc('4.3', '1.7'))
|
||||
assert (z+y).ae(mpc('4.3'))
|
||||
assert (y+z).ae(mpc('4.3'))
|
||||
assert (z+w).ae(mpc('4.3', '1.7'))
|
||||
assert (w+z).ae(mpc('4.3', '1.7'))
|
||||
x-y; y-x; x-w; w-x; z-y; y-z; z-w; w-z
|
||||
x*y; y*x; x*w; w*x; z*y; y*z; z*w; w*z
|
||||
x/y; y/x; x/w; w/x; z/y; y/z; z/w; w/z
|
||||
x**y; y**x; x**w; w**x; z**y; y**z; z**w; w**z
|
||||
x==y; y==x; x==w; w==x; z==y; y==z; z==w; w==z
|
||||
mp.dps = 15
|
||||
assert x.__add__(a) is NotImplemented
|
||||
assert x.__radd__(a) is NotImplemented
|
||||
assert x.__lt__(a) is NotImplemented
|
||||
assert x.__gt__(a) is NotImplemented
|
||||
assert x.__le__(a) is NotImplemented
|
||||
assert x.__ge__(a) is NotImplemented
|
||||
assert x.__eq__(a) is NotImplemented
|
||||
assert x.__ne__(a) is NotImplemented
|
||||
# implementation detail
|
||||
if hasattr(x, "__cmp__"):
|
||||
assert x.__cmp__(a) is NotImplemented
|
||||
assert x.__sub__(a) is NotImplemented
|
||||
assert x.__rsub__(a) is NotImplemented
|
||||
assert x.__mul__(a) is NotImplemented
|
||||
assert x.__rmul__(a) is NotImplemented
|
||||
assert x.__div__(a) is NotImplemented
|
||||
assert x.__rdiv__(a) is NotImplemented
|
||||
assert x.__mod__(a) is NotImplemented
|
||||
assert x.__rmod__(a) is NotImplemented
|
||||
assert x.__pow__(a) is NotImplemented
|
||||
assert x.__rpow__(a) is NotImplemented
|
||||
assert z.__add__(a) is NotImplemented
|
||||
assert z.__radd__(a) is NotImplemented
|
||||
assert z.__eq__(a) is NotImplemented
|
||||
assert z.__ne__(a) is NotImplemented
|
||||
assert z.__sub__(a) is NotImplemented
|
||||
assert z.__rsub__(a) is NotImplemented
|
||||
assert z.__mul__(a) is NotImplemented
|
||||
assert z.__rmul__(a) is NotImplemented
|
||||
assert z.__div__(a) is NotImplemented
|
||||
assert z.__rdiv__(a) is NotImplemented
|
||||
assert z.__pow__(a) is NotImplemented
|
||||
assert z.__rpow__(a) is NotImplemented
|
||||
|
||||
def test_mpmathify():
|
||||
assert mpmathify('1/2') == 0.5
|
||||
assert mpmathify('(1.0+1.0j)') == mpc(1, 1)
|
||||
assert mpmathify('(1.2e-10 - 3.4e5j)') == mpc('1.2e-10', '-3.4e5')
|
||||
assert mpmathify('1j') == mpc(1j)
|
||||
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_diff():
|
||||
assert diff(log, 2.0, n=0).ae(log(2))
|
||||
assert diff(cos, 1.0).ae(-sin(1))
|
||||
assert diff(abs, 0.0) == 0
|
||||
assert diff(abs, 0.0, direction=1) == 1
|
||||
assert diff(abs, 0.0, direction=-1) == -1
|
||||
assert diff(exp, 1.0).ae(e)
|
||||
assert diff(exp, 1.0, n=5).ae(e)
|
||||
assert diff(exp, 2.0, n=5, direction=3*j).ae(e**2)
|
||||
assert diff(lambda x: x**2, 3.0, method='quad').ae(6)
|
||||
assert diff(lambda x: 3+x**5, 3.0, n=2, method='quad').ae(540)
|
||||
assert diff(lambda x: 3+x**5, 3.0, n=2, method='step').ae(540)
|
||||
assert diffun(sin)(2).ae(cos(2))
|
||||
assert diffun(sin, n=2)(2).ae(-sin(2))
|
||||
|
||||
def test_taylor():
|
||||
# Easy to test since the coefficients are exact in floating-point
|
||||
assert taylor(sqrt, 1, 4) == [1, 0.5, -0.125, 0.0625, -0.0390625]
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
from mpmath.libmp import *
|
||||
from mpmath import mpf, mp
|
||||
|
||||
from random import randint, choice, seed
|
||||
|
||||
all_modes = [round_floor, round_ceiling, round_down, round_up, round_nearest]
|
||||
|
||||
fb = from_bstr
|
||||
fi = from_int
|
||||
ff = from_float
|
||||
|
||||
|
||||
def test_div_1_3():
|
||||
a = fi(1)
|
||||
b = fi(3)
|
||||
c = fi(-1)
|
||||
|
||||
# floor rounds down, ceiling rounds up
|
||||
assert mpf_div(a, b, 7, round_floor) == fb('0.01010101')
|
||||
assert mpf_div(a, b, 7, round_ceiling) == fb('0.01010110')
|
||||
assert mpf_div(a, b, 7, round_down) == fb('0.01010101')
|
||||
assert mpf_div(a, b, 7, round_up) == fb('0.01010110')
|
||||
assert mpf_div(a, b, 7, round_nearest) == fb('0.01010101')
|
||||
|
||||
# floor rounds up, ceiling rounds down
|
||||
assert mpf_div(c, b, 7, round_floor) == fb('-0.01010110')
|
||||
assert mpf_div(c, b, 7, round_ceiling) == fb('-0.01010101')
|
||||
assert mpf_div(c, b, 7, round_down) == fb('-0.01010101')
|
||||
assert mpf_div(c, b, 7, round_up) == fb('-0.01010110')
|
||||
assert mpf_div(c, b, 7, round_nearest) == fb('-0.01010101')
|
||||
|
||||
def test_mpf_divi_1_3():
|
||||
a = 1
|
||||
b = fi(3)
|
||||
c = -1
|
||||
assert mpf_rdiv_int(a, b, 7, round_floor) == fb('0.01010101')
|
||||
assert mpf_rdiv_int(a, b, 7, round_ceiling) == fb('0.01010110')
|
||||
assert mpf_rdiv_int(a, b, 7, round_down) == fb('0.01010101')
|
||||
assert mpf_rdiv_int(a, b, 7, round_up) == fb('0.01010110')
|
||||
assert mpf_rdiv_int(a, b, 7, round_nearest) == fb('0.01010101')
|
||||
assert mpf_rdiv_int(c, b, 7, round_floor) == fb('-0.01010110')
|
||||
assert mpf_rdiv_int(c, b, 7, round_ceiling) == fb('-0.01010101')
|
||||
assert mpf_rdiv_int(c, b, 7, round_down) == fb('-0.01010101')
|
||||
assert mpf_rdiv_int(c, b, 7, round_up) == fb('-0.01010110')
|
||||
assert mpf_rdiv_int(c, b, 7, round_nearest) == fb('-0.01010101')
|
||||
|
||||
|
||||
def test_div_300():
|
||||
|
||||
q = fi(1000000)
|
||||
a = fi(300499999) # a/q is a little less than a half-integer
|
||||
b = fi(300500000) # b/q exactly a half-integer
|
||||
c = fi(300500001) # c/q is a little more than a half-integer
|
||||
|
||||
# Check nearest integer rounding (prec=9 as 2**8 < 300 < 2**9)
|
||||
|
||||
assert mpf_div(a, q, 9, round_down) == fi(300)
|
||||
assert mpf_div(b, q, 9, round_down) == fi(300)
|
||||
assert mpf_div(c, q, 9, round_down) == fi(300)
|
||||
assert mpf_div(a, q, 9, round_up) == fi(301)
|
||||
assert mpf_div(b, q, 9, round_up) == fi(301)
|
||||
assert mpf_div(c, q, 9, round_up) == fi(301)
|
||||
|
||||
# Nearest even integer is down
|
||||
assert mpf_div(a, q, 9, round_nearest) == fi(300)
|
||||
assert mpf_div(b, q, 9, round_nearest) == fi(300)
|
||||
assert mpf_div(c, q, 9, round_nearest) == fi(301)
|
||||
|
||||
# Nearest even integer is up
|
||||
a = fi(301499999)
|
||||
b = fi(301500000)
|
||||
c = fi(301500001)
|
||||
assert mpf_div(a, q, 9, round_nearest) == fi(301)
|
||||
assert mpf_div(b, q, 9, round_nearest) == fi(302)
|
||||
assert mpf_div(c, q, 9, round_nearest) == fi(302)
|
||||
|
||||
|
||||
def test_tight_integer_division():
|
||||
# Test that integer division at tightest possible precision is exact
|
||||
N = 100
|
||||
seed(1)
|
||||
for i in range(N):
|
||||
a = choice([1, -1]) * randint(1, 1<<randint(10, 100))
|
||||
b = choice([1, -1]) * randint(1, 1<<randint(10, 100))
|
||||
p = a * b
|
||||
width = bitcount(abs(b)) - trailing(b)
|
||||
a = fi(a); b = fi(b); p = fi(p)
|
||||
for mode in all_modes:
|
||||
assert mpf_div(p, a, width, mode) == b
|
||||
|
||||
|
||||
def test_epsilon_rounding():
|
||||
# Verify that mpf_div uses infinite precision; this result will
|
||||
# appear to be exactly 0.101 to a near-sighted algorithm
|
||||
|
||||
a = fb('0.101' + ('0'*200) + '1')
|
||||
b = fb('1.10101')
|
||||
c = mpf_mul(a, b, 250, round_floor) # exact
|
||||
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a # exact
|
||||
|
||||
assert mpf_div(c, b, 2, round_down) == fb('0.10')
|
||||
assert mpf_div(c, b, 3, round_down) == fb('0.101')
|
||||
assert mpf_div(c, b, 2, round_up) == fb('0.11')
|
||||
assert mpf_div(c, b, 3, round_up) == fb('0.110')
|
||||
assert mpf_div(c, b, 2, round_floor) == fb('0.10')
|
||||
assert mpf_div(c, b, 3, round_floor) == fb('0.101')
|
||||
assert mpf_div(c, b, 2, round_ceiling) == fb('0.11')
|
||||
assert mpf_div(c, b, 3, round_ceiling) == fb('0.110')
|
||||
|
||||
# The same for negative numbers
|
||||
a = fb('-0.101' + ('0'*200) + '1')
|
||||
b = fb('1.10101')
|
||||
c = mpf_mul(a, b, 250, round_floor)
|
||||
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a
|
||||
|
||||
assert mpf_div(c, b, 2, round_down) == fb('-0.10')
|
||||
assert mpf_div(c, b, 3, round_up) == fb('-0.110')
|
||||
|
||||
# Floor goes up, ceiling goes down
|
||||
assert mpf_div(c, b, 2, round_floor) == fb('-0.11')
|
||||
assert mpf_div(c, b, 3, round_floor) == fb('-0.110')
|
||||
assert mpf_div(c, b, 2, round_ceiling) == fb('-0.10')
|
||||
assert mpf_div(c, b, 3, round_ceiling) == fb('-0.101')
|
||||
|
||||
|
||||
def test_mod():
|
||||
mp.dps = 15
|
||||
assert mpf(234) % 1 == 0
|
||||
assert mpf(-3) % 256 == 253
|
||||
assert mpf(0.25) % 23490.5 == 0.25
|
||||
assert mpf(0.25) % -23490.5 == -23490.25
|
||||
assert mpf(-0.25) % 23490.5 == 23490.25
|
||||
assert mpf(-0.25) % -23490.5 == -0.25
|
||||
# Check that these cases are handled efficiently
|
||||
assert mpf('1e10000000000') % 1 == 0
|
||||
assert mpf('1.23e-1000000000') % 1 == mpf('1.23e-1000000000')
|
||||
# test __rmod__
|
||||
assert 3 % mpf('1.75') == 1.25
|
||||
|
||||
def test_div_negative_rnd_bug():
|
||||
mp.dps = 15
|
||||
assert (-3) / mpf('0.1531879017645047') == mpf('-19.583791966887116')
|
||||
assert mpf('-2.6342475750861301') / mpf('0.35126216427941814') == mpf('-7.4993775104985909')
|
||||
|
|
@ -1,537 +0,0 @@
|
|||
"""
|
||||
Limited tests of the elliptic functions module. A full suite of
|
||||
extensive testing can be found in elliptic_torture_tests.py
|
||||
|
||||
Author of the first version: M.T. Taschuk
|
||||
|
||||
References:
|
||||
|
||||
[1] Abramowitz & Stegun. 'Handbook of Mathematical Functions, 9th Ed.',
|
||||
(Dover duplicate of 1972 edition)
|
||||
[2] Whittaker 'A Course of Modern Analysis, 4th Ed.', 1946,
|
||||
Cambridge University Press
|
||||
|
||||
"""
|
||||
|
||||
import mpmath
|
||||
import random
|
||||
|
||||
from mpmath import *
|
||||
|
||||
def mpc_ae(a, b, eps=eps):
|
||||
res = True
|
||||
res = res and a.real.ae(b.real, eps)
|
||||
res = res and a.imag.ae(b.imag, eps)
|
||||
return res
|
||||
|
||||
zero = mpf(0)
|
||||
one = mpf(1)
|
||||
|
||||
def test_calculate_nome():
|
||||
mp.dps = 100
|
||||
|
||||
q = calculate_nome(zero)
|
||||
assert(q == zero)
|
||||
|
||||
mp.dps = 25
|
||||
# used Mathematica's EllipticNomeQ[m]
|
||||
math1 = [(mpf(1)/10, mpf('0.006584651553858370274473060')),
|
||||
(mpf(2)/10, mpf('0.01394285727531826872146409')),
|
||||
(mpf(3)/10, mpf('0.02227743615715350822901627')),
|
||||
(mpf(4)/10, mpf('0.03188334731336317755064299')),
|
||||
(mpf(5)/10, mpf('0.04321391826377224977441774')),
|
||||
(mpf(6)/10, mpf('0.05702025781460967637754953')),
|
||||
(mpf(7)/10, mpf('0.07468994353717944761143751')),
|
||||
(mpf(8)/10, mpf('0.09927369733882489703607378')),
|
||||
(mpf(9)/10, mpf('0.1401731269542615524091055')),
|
||||
(mpf(9)/10, mpf('0.1401731269542615524091055'))]
|
||||
|
||||
for i in math1:
|
||||
m = i[0]
|
||||
q = calculate_nome(sqrt(m))
|
||||
assert q.ae(i[1])
|
||||
|
||||
mp.dps = 15
|
||||
|
||||
def test_jtheta():
|
||||
mp.dps = 25
|
||||
|
||||
z = q = zero
|
||||
for n in range(1,5):
|
||||
value = jtheta(n, z, q)
|
||||
assert(value == (n-1)//2)
|
||||
|
||||
for q in [one, mpf(2)]:
|
||||
for n in range(1,5):
|
||||
raised = True
|
||||
try:
|
||||
r = jtheta(n, z, q)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
raised = False
|
||||
assert(raised)
|
||||
|
||||
z = one/10
|
||||
q = one/11
|
||||
|
||||
# Mathematical N[EllipticTheta[1, 1/10, 1/11], 25]
|
||||
res = mpf('0.1069552990104042681962096')
|
||||
result = jtheta(1, z, q)
|
||||
assert(result.ae(res))
|
||||
|
||||
# Mathematica N[EllipticTheta[2, 1/10, 1/11], 25]
|
||||
res = mpf('1.101385760258855791140606')
|
||||
result = jtheta(2, z, q)
|
||||
assert(result.ae(res))
|
||||
|
||||
# Mathematica N[EllipticTheta[3, 1/10, 1/11], 25]
|
||||
res = mpf('1.178319743354331061795905')
|
||||
result = jtheta(3, z, q)
|
||||
assert(result.ae(res))
|
||||
|
||||
# Mathematica N[EllipticTheta[4, 1/10, 1/11], 25]
|
||||
res = mpf('0.8219318954665153577314573')
|
||||
result = jtheta(4, z, q)
|
||||
assert(result.ae(res))
|
||||
|
||||
# test for sin zeros for jtheta(1, z, q)
|
||||
# test for cos zeros for jtheta(2, z, q)
|
||||
z1 = pi
|
||||
z2 = pi/2
|
||||
for i in range(10):
|
||||
qstring = str(random.random())
|
||||
q = mpf(qstring)
|
||||
result = jtheta(1, z1, q)
|
||||
assert(result.ae(0))
|
||||
result = jtheta(2, z2, q)
|
||||
assert(result.ae(0))
|
||||
mp.dps = 15
|
||||
|
||||
|
||||
def test_jtheta_issue39():
|
||||
# near the circle of covergence |q| = 1 the convergence slows
|
||||
# down; for |q| > Q_LIM the theta functions raise ValueError
|
||||
mp.dps = 30
|
||||
mp.dps += 30
|
||||
q = mpf(6)/10 - one/10**6 - mpf(8)/10 * j
|
||||
mp.dps -= 30
|
||||
# Mathematica run first
|
||||
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 2000]
|
||||
# then it works:
|
||||
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 30]
|
||||
res = mpf('32.0031009628901652627099524264') + \
|
||||
mpf('16.6153027998236087899308935624') * j
|
||||
result = jtheta(3, 1, q)
|
||||
# check that for abs(q) > Q_LIM a ValueError exception is raised
|
||||
mp.dps += 30
|
||||
q = mpf(6)/10 - one/10**7 - mpf(8)/10 * j
|
||||
mp.dps -= 30
|
||||
try:
|
||||
result = jtheta(3, 1, q)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert(False)
|
||||
|
||||
# bug reported in issue39
|
||||
mp.dps = 100
|
||||
z = (1+j)/3
|
||||
q = mpf(368983957219251)/10**15 + mpf(636363636363636)/10**15 * j
|
||||
# Mathematica N[EllipticTheta[1, z, q], 35]
|
||||
res = mpf('2.4439389177990737589761828991467471') + \
|
||||
mpf('0.5446453005688226915290954851851490') *j
|
||||
mp.dps = 30
|
||||
result = jtheta(1, z, q)
|
||||
assert(result.ae(res))
|
||||
mp.dps = 80
|
||||
z = 3 + 4*j
|
||||
q = 0.5 + 0.5*j
|
||||
r1 = jtheta(1, z, q)
|
||||
mp.dps = 15
|
||||
r2 = jtheta(1, z, q)
|
||||
assert r1.ae(r2)
|
||||
mp.dps = 80
|
||||
z = 3 + j
|
||||
q1 = exp(j*3)
|
||||
# longer test
|
||||
# for n in range(1, 6)
|
||||
for n in range(1, 2):
|
||||
mp.dps = 80
|
||||
q = q1*(1 - mpf(1)/10**n)
|
||||
r1 = jtheta(1, z, q)
|
||||
mp.dps = 15
|
||||
r2 = jtheta(1, z, q)
|
||||
assert r1.ae(r2)
|
||||
mp.dps = 15
|
||||
# issue 39 about high derivatives
|
||||
assert jtheta(3, 4.5, 0.25, 9).ae(1359.04892680683)
|
||||
assert jtheta(3, 4.5, 0.25, 50).ae(-6.14832772630905e+33)
|
||||
mp.dps = 50
|
||||
r = jtheta(3, 4.5, 0.25, 9)
|
||||
assert r.ae('1359.048926806828939547859396600218966947753213803')
|
||||
r = jtheta(3, 4.5, 0.25, 50)
|
||||
assert r.ae('-6148327726309051673317975084654262.4119215720343656')
|
||||
|
||||
def test_jtheta_identities():
|
||||
"""
|
||||
Tests the some of the jacobi identidies found in Abramowitz,
|
||||
Sec. 16.28, Pg. 576. The identities are tested to 1 part in 10^98.
|
||||
"""
|
||||
mp.dps = 110
|
||||
eps1 = ldexp(eps, 30)
|
||||
|
||||
for i in range(10):
|
||||
qstring = str(random.random())
|
||||
q = mpf(qstring)
|
||||
|
||||
zstring = str(10*random.random())
|
||||
z = mpf(zstring)
|
||||
# Abramowitz 16.28.1
|
||||
# v_1(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_2(0, q)**2
|
||||
# - v_2(z, q)**2 * v_3(0, q)**2
|
||||
term1 = (jtheta(1, z, q)**2) * (jtheta(4, zero, q)**2)
|
||||
term2 = (jtheta(3, z, q)**2) * (jtheta(2, zero, q)**2)
|
||||
term3 = (jtheta(2, z, q)**2) * (jtheta(3, zero, q)**2)
|
||||
equality = term1 - term2 + term3
|
||||
assert(equality.ae(0, eps1))
|
||||
|
||||
zstring = str(100*random.random())
|
||||
z = mpf(zstring)
|
||||
# Abramowitz 16.28.2
|
||||
# v_2(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_2(0, q)**2
|
||||
# - v_1(z, q)**2 * v_3(0, q)**2
|
||||
term1 = (jtheta(2, z, q)**2) * (jtheta(4, zero, q)**2)
|
||||
term2 = (jtheta(4, z, q)**2) * (jtheta(2, zero, q)**2)
|
||||
term3 = (jtheta(1, z, q)**2) * (jtheta(3, zero, q)**2)
|
||||
equality = term1 - term2 + term3
|
||||
assert(equality.ae(0, eps1))
|
||||
|
||||
# Abramowitz 16.28.3
|
||||
# v_3(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_3(0, q)**2
|
||||
# - v_1(z, q)**2 * v_2(0, q)**2
|
||||
term1 = (jtheta(3, z, q)**2) * (jtheta(4, zero, q)**2)
|
||||
term2 = (jtheta(4, z, q)**2) * (jtheta(3, zero, q)**2)
|
||||
term3 = (jtheta(1, z, q)**2) * (jtheta(2, zero, q)**2)
|
||||
equality = term1 - term2 + term3
|
||||
assert(equality.ae(0, eps1))
|
||||
|
||||
# Abramowitz 16.28.4
|
||||
# v_4(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_3(0, q)**2
|
||||
# - v_2(z, q)**2 * v_2(0, q)**2
|
||||
term1 = (jtheta(4, z, q)**2) * (jtheta(4, zero, q)**2)
|
||||
term2 = (jtheta(3, z, q)**2) * (jtheta(3, zero, q)**2)
|
||||
term3 = (jtheta(2, z, q)**2) * (jtheta(2, zero, q)**2)
|
||||
equality = term1 - term2 + term3
|
||||
assert(equality.ae(0, eps1))
|
||||
|
||||
# Abramowitz 16.28.5
|
||||
# v_2(0, q)**4 + v_4(0, q)**4 == v_3(0, q)**4
|
||||
term1 = (jtheta(2, zero, q))**4
|
||||
term2 = (jtheta(4, zero, q))**4
|
||||
term3 = (jtheta(3, zero, q))**4
|
||||
equality = term1 + term2 - term3
|
||||
assert(equality.ae(0, eps1))
|
||||
mp.dps = 15
|
||||
|
||||
def test_jtheta_complex():
|
||||
mp.dps = 30
|
||||
z = mpf(1)/4 + j/8
|
||||
q = mpf(1)/3 + j/7
|
||||
# Mathematica N[EllipticTheta[1, 1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('0.31618034835986160705729105731678285') + \
|
||||
mpf('0.07542013825835103435142515194358975') * j
|
||||
r = jtheta(1, z, q)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
# Mathematica N[EllipticTheta[2, 1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('1.6530986428239765928634711417951828') + \
|
||||
mpf('0.2015344864707197230526742145361455') * j
|
||||
r = jtheta(2, z, q)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
# Mathematica N[EllipticTheta[3, 1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('1.6520564411784228184326012700348340') + \
|
||||
mpf('0.1998129119671271328684690067401823') * j
|
||||
r = jtheta(3, z, q)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
# Mathematica N[EllipticTheta[4, 1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('0.37619082382228348252047624089973824') - \
|
||||
mpf('0.15623022130983652972686227200681074') * j
|
||||
r = jtheta(4, z, q)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
# check some theta function identities
|
||||
mp.dos = 100
|
||||
z = mpf(1)/4 + j/8
|
||||
q = mpf(1)/3 + j/7
|
||||
mp.dps += 10
|
||||
a = [0,0, jtheta(2, 0, q), jtheta(3, 0, q), jtheta(4, 0, q)]
|
||||
t = [0, jtheta(1, z, q), jtheta(2, z, q), jtheta(3, z, q), jtheta(4, z, q)]
|
||||
r = [(t[2]*a[4])**2 - (t[4]*a[2])**2 + (t[1] *a[3])**2,
|
||||
(t[3]*a[4])**2 - (t[4]*a[3])**2 + (t[1] *a[2])**2,
|
||||
(t[1]*a[4])**2 - (t[3]*a[2])**2 + (t[2] *a[3])**2,
|
||||
(t[4]*a[4])**2 - (t[3]*a[3])**2 + (t[2] *a[2])**2,
|
||||
a[2]**4 + a[4]**4 - a[3]**4]
|
||||
mp.dps -= 10
|
||||
for x in r:
|
||||
assert(mpc_ae(x, mpc(0)))
|
||||
mp.dps = 15
|
||||
|
||||
def test_djtheta():
|
||||
mp.dps = 30
|
||||
|
||||
z = one/7 + j/3
|
||||
q = one/8 + j/5
|
||||
# Mathematica N[EllipticThetaPrime[1, 1/7 + I/3, 1/8 + I/5], 35]
|
||||
res = mpf('1.5555195883277196036090928995803201') - \
|
||||
mpf('0.02439761276895463494054149673076275') * j
|
||||
result = jtheta(1, z, q, 1)
|
||||
assert(mpc_ae(result, res))
|
||||
|
||||
# Mathematica N[EllipticThetaPrime[2, 1/7 + I/3, 1/8 + I/5], 35]
|
||||
res = mpf('0.19825296689470982332701283509685662') - \
|
||||
mpf('0.46038135182282106983251742935250009') * j
|
||||
result = jtheta(2, z, q, 1)
|
||||
assert(mpc_ae(result, res))
|
||||
|
||||
# Mathematica N[EllipticThetaPrime[3, 1/7 + I/3, 1/8 + I/5], 35]
|
||||
res = mpf('0.36492498415476212680896699407390026') - \
|
||||
mpf('0.57743812698666990209897034525640369') * j
|
||||
result = jtheta(3, z, q, 1)
|
||||
assert(mpc_ae(result, res))
|
||||
|
||||
# Mathematica N[EllipticThetaPrime[4, 1/7 + I/3, 1/8 + I/5], 35]
|
||||
res = mpf('-0.38936892528126996010818803742007352') + \
|
||||
mpf('0.66549886179739128256269617407313625') * j
|
||||
result = jtheta(4, z, q, 1)
|
||||
assert(mpc_ae(result, res))
|
||||
|
||||
for i in range(10):
|
||||
q = (one*random.random() + j*random.random())/2
|
||||
# identity in Wittaker, Watson &21.41
|
||||
a = jtheta(1, 0, q, 1)
|
||||
b = jtheta(2, 0, q)*jtheta(3, 0, q)*jtheta(4, 0, q)
|
||||
assert(a.ae(b))
|
||||
|
||||
# test higher derivatives
|
||||
mp.dps = 20
|
||||
for q,z in [(one/3, one/5), (one/3 + j/8, one/5),
|
||||
(one/3, one/5 + j/8), (one/3 + j/7, one/5 + j/8)]:
|
||||
for n in [1, 2, 3, 4]:
|
||||
r = jtheta(n, z, q, 2)
|
||||
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=2)
|
||||
assert r.ae(r1)
|
||||
r = jtheta(n, z, q, 3)
|
||||
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=3)
|
||||
assert r.ae(r1)
|
||||
|
||||
# identity in Wittaker, Watson &21.41
|
||||
q = one/3
|
||||
z = zero
|
||||
a = [0]*5
|
||||
a[1] = jtheta(1, z, q, 3)/jtheta(1, z, q, 1)
|
||||
for n in [2,3,4]:
|
||||
a[n] = jtheta(n, z, q, 2)/jtheta(n, z, q)
|
||||
equality = a[2] + a[3] + a[4] - a[1]
|
||||
assert(equality.ae(0))
|
||||
mp.dps = 15
|
||||
|
||||
def test_jsn():
|
||||
"""
|
||||
Test some special cases of the sn(z, q) function.
|
||||
"""
|
||||
mp.dps = 100
|
||||
|
||||
# trival case
|
||||
result = jsn(zero, zero)
|
||||
assert(result == zero)
|
||||
|
||||
# Abramowitz Table 16.5
|
||||
#
|
||||
# sn(0, m) = 0
|
||||
|
||||
for i in range(10):
|
||||
qstring = str(random.random())
|
||||
q = mpf(qstring)
|
||||
|
||||
equality = jsn(zero, q)
|
||||
assert(equality.ae(0))
|
||||
|
||||
# Abramowitz Table 16.6.1
|
||||
#
|
||||
# sn(z, 0) = sin(z), m == 0
|
||||
#
|
||||
# sn(z, 1) = tanh(z), m == 1
|
||||
#
|
||||
# It would be nice to test these, but I find that they run
|
||||
# in to numerical trouble. I'm currently treating as a boundary
|
||||
# case for sn function.
|
||||
|
||||
mp.dps = 25
|
||||
arg = one/10
|
||||
#N[JacobiSN[1/10, 2^-100], 25]
|
||||
res = mpf('0.09983341664682815230681420')
|
||||
m = ldexp(one, -100)
|
||||
result = jsn(arg, m)
|
||||
assert(result.ae(res))
|
||||
|
||||
# N[JacobiSN[1/10, 1/10], 25]
|
||||
res = mpf('0.09981686718599080096451168')
|
||||
result = jsn(arg, arg)
|
||||
assert(result.ae(res))
|
||||
mp.dps = 15
|
||||
|
||||
def test_jcn():
|
||||
"""
|
||||
Test some special cases of the cn(z, q) function.
|
||||
"""
|
||||
mp.dps = 100
|
||||
|
||||
# Abramowitz Table 16.5
|
||||
# cn(0, q) = 1
|
||||
qstring = str(random.random())
|
||||
q = mpf(qstring)
|
||||
cn = jcn(zero, q)
|
||||
assert(cn.ae(one))
|
||||
|
||||
# Abramowitz Table 16.6.2
|
||||
#
|
||||
# cn(u, 0) = cos(u), m == 0
|
||||
#
|
||||
# cn(u, 1) = sech(z), m == 1
|
||||
#
|
||||
# It would be nice to test these, but I find that they run
|
||||
# in to numerical trouble. I'm currently treating as a boundary
|
||||
# case for cn function.
|
||||
|
||||
mp.dps = 25
|
||||
arg = one/10
|
||||
m = ldexp(one, -100)
|
||||
#N[JacobiCN[1/10, 2^-100], 25]
|
||||
res = mpf('0.9950041652780257660955620')
|
||||
result = jcn(arg, m)
|
||||
assert(result.ae(res))
|
||||
|
||||
# N[JacobiCN[1/10, 1/10], 25]
|
||||
res = mpf('0.9950058256237368748520459')
|
||||
result = jcn(arg, arg)
|
||||
assert(result.ae(res))
|
||||
mp.dps = 15
|
||||
|
||||
def test_jdn():
|
||||
"""
|
||||
Test some special cases of the dn(z, q) function.
|
||||
"""
|
||||
mp.dps = 100
|
||||
|
||||
# Abramowitz Table 16.5
|
||||
# dn(0, q) = 1
|
||||
mstring = str(random.random())
|
||||
m = mpf(mstring)
|
||||
|
||||
dn = jdn(zero, m)
|
||||
assert(dn.ae(one))
|
||||
|
||||
mp.dps = 25
|
||||
# N[JacobiDN[1/10, 1/10], 25]
|
||||
res = mpf('0.9995017055025556219713297')
|
||||
arg = one/10
|
||||
result = jdn(arg, arg)
|
||||
assert(result.ae(res))
|
||||
mp.dps = 15
|
||||
|
||||
|
||||
def test_sn_cn_dn_identities():
|
||||
"""
|
||||
Tests the some of the jacobi elliptic function identities found
|
||||
on Mathworld. Haven't found in Abramowitz.
|
||||
"""
|
||||
mp.dps = 100
|
||||
N = 5
|
||||
for i in range(N):
|
||||
qstring = str(random.random())
|
||||
q = mpf(qstring)
|
||||
zstring = str(100*random.random())
|
||||
z = mpf(zstring)
|
||||
|
||||
# MathWorld
|
||||
# sn(z, q)**2 + cn(z, q)**2 == 1
|
||||
term1 = jsn(z, q)**2
|
||||
term2 = jcn(z, q)**2
|
||||
equality = one - term1 - term2
|
||||
assert(equality.ae(0))
|
||||
|
||||
# MathWorld
|
||||
# k**2 * sn(z, m)**2 + dn(z, m)**2 == 1
|
||||
for i in range(N):
|
||||
mstring = str(random.random())
|
||||
m = mpf(qstring)
|
||||
k = m.sqrt()
|
||||
zstring = str(10*random.random())
|
||||
z = mpf(zstring)
|
||||
term1 = k**2 * jsn(z, m)**2
|
||||
term2 = jdn(z, m)**2
|
||||
equality = one - term1 - term2
|
||||
assert(equality.ae(0))
|
||||
|
||||
|
||||
for i in range(N):
|
||||
mstring = str(random.random())
|
||||
m = mpf(mstring)
|
||||
k = m.sqrt()
|
||||
zstring = str(random.random())
|
||||
z = mpf(zstring)
|
||||
|
||||
# MathWorld
|
||||
# k**2 * cn(z, m)**2 + (1 - k**2) = dn(z, m)**2
|
||||
term1 = k**2 * jcn(z, m)**2
|
||||
term2 = 1 - k**2
|
||||
term3 = jdn(z, m)**2
|
||||
equality = term3 - term1 - term2
|
||||
assert(equality.ae(0))
|
||||
|
||||
K = ellipk(k**2)
|
||||
# Abramowitz Table 16.5
|
||||
# sn(K, m) = 1; K is K(k), first complete elliptic integral
|
||||
r = jsn(K, m)
|
||||
assert(r.ae(one))
|
||||
|
||||
# Abramowitz Table 16.5
|
||||
# cn(K, q) = 0; K is K(k), first complete elliptic integral
|
||||
equality = jcn(K, m)
|
||||
assert(equality.ae(0))
|
||||
|
||||
# Abramowitz Table 16.6.3
|
||||
# dn(z, 0) = 1, m == 0
|
||||
z = m
|
||||
value = jdn(z, zero)
|
||||
assert(value.ae(one))
|
||||
|
||||
mp.dps = 15
|
||||
|
||||
def test_sn_cn_dn_complex():
|
||||
mp.dps = 30
|
||||
# N[JacobiSN[1/4 + I/8, 1/3 + I/7], 35] in Mathematica
|
||||
res = mpf('0.2495674401066275492326652143537') + \
|
||||
mpf('0.12017344422863833381301051702823') * j
|
||||
u = mpf(1)/4 + j/8
|
||||
m = mpf(1)/3 + j/7
|
||||
r = jsn(u, m)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
#N[JacobiCN[1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('0.9762691700944007312693721148331') - \
|
||||
mpf('0.0307203994181623243583169154824')*j
|
||||
r = jcn(u, m)
|
||||
#assert r.real.ae(res.real)
|
||||
#assert r.imag.ae(res.imag)
|
||||
assert(mpc_ae(r, res))
|
||||
|
||||
#N[JacobiDN[1/4 + I/8, 1/3 + I/7], 35]
|
||||
res = mpf('0.99639490163039577560547478589753039') - \
|
||||
mpf('0.01346296520008176393432491077244994')*j
|
||||
r = jdn(u, m)
|
||||
assert(mpc_ae(r, res))
|
||||
mp.dps = 15
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,882 +0,0 @@
|
|||
from mpmath.libmp import *
|
||||
from mpmath import *
|
||||
import random
|
||||
import time
|
||||
import math
|
||||
import cmath
|
||||
|
||||
def mpc_ae(a, b, eps=eps):
|
||||
res = True
|
||||
res = res and a.real.ae(b.real, eps)
|
||||
res = res and a.imag.ae(b.imag, eps)
|
||||
return res
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Constants and functions
|
||||
#
|
||||
|
||||
tpi = "3.1415926535897932384626433832795028841971693993751058209749445923078\
|
||||
1640628620899862803482534211706798"
|
||||
te = "2.71828182845904523536028747135266249775724709369995957496696762772407\
|
||||
663035354759457138217852516642743"
|
||||
tdegree = "0.017453292519943295769236907684886127134428718885417254560971914\
|
||||
4017100911460344944368224156963450948221"
|
||||
teuler = "0.5772156649015328606065120900824024310421593359399235988057672348\
|
||||
84867726777664670936947063291746749516"
|
||||
tln2 = "0.693147180559945309417232121458176568075500134360255254120680009493\
|
||||
393621969694715605863326996418687542"
|
||||
tln10 = "2.30258509299404568401799145468436420760110148862877297603332790096\
|
||||
757260967735248023599720508959829834"
|
||||
tcatalan = "0.91596559417721901505460351493238411077414937428167213426649811\
|
||||
9621763019776254769479356512926115106249"
|
||||
tkhinchin = "2.6854520010653064453097148354817956938203822939944629530511523\
|
||||
4555721885953715200280114117493184769800"
|
||||
tglaisher = "1.2824271291006226368753425688697917277676889273250011920637400\
|
||||
2174040630885882646112973649195820237439420646"
|
||||
tapery = "1.2020569031595942853997381615114499907649862923404988817922715553\
|
||||
4183820578631309018645587360933525815"
|
||||
tphi = "1.618033988749894848204586834365638117720309179805762862135448622705\
|
||||
26046281890244970720720418939113748475"
|
||||
tmertens = "0.26149721284764278375542683860869585905156664826119920619206421\
|
||||
3924924510897368209714142631434246651052"
|
||||
ttwinprime = "0.660161815846869573927812110014555778432623360284733413319448\
|
||||
423335405642304495277143760031413839867912"
|
||||
|
||||
def test_constants():
|
||||
for prec in [3, 7, 10, 15, 20, 37, 80, 100, 29]:
|
||||
mp.dps = prec
|
||||
assert pi == mpf(tpi)
|
||||
assert e == mpf(te)
|
||||
assert degree == mpf(tdegree)
|
||||
assert euler == mpf(teuler)
|
||||
assert ln2 == mpf(tln2)
|
||||
assert ln10 == mpf(tln10)
|
||||
assert catalan == mpf(tcatalan)
|
||||
assert khinchin == mpf(tkhinchin)
|
||||
assert glaisher == mpf(tglaisher)
|
||||
assert phi == mpf(tphi)
|
||||
if prec < 50:
|
||||
assert mertens == mpf(tmertens)
|
||||
assert twinprime == mpf(ttwinprime)
|
||||
mp.dps = 15
|
||||
assert pi >= -1
|
||||
assert pi > 2
|
||||
assert pi > 3
|
||||
assert pi < 4
|
||||
|
||||
def test_exact_sqrts():
|
||||
for i in range(20000):
|
||||
assert sqrt(mpf(i*i)) == i
|
||||
random.seed(1)
|
||||
for prec in [100, 300, 1000, 10000]:
|
||||
mp.dps = prec
|
||||
for i in range(20):
|
||||
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
|
||||
assert sqrt(mpf(A*A)) == A
|
||||
mp.dps = 15
|
||||
for i in range(100):
|
||||
for a in [1, 8, 25, 112307]:
|
||||
assert sqrt(mpf((a*a, 2*i))) == mpf((a, i))
|
||||
assert sqrt(mpf((a*a, -2*i))) == mpf((a, -i))
|
||||
|
||||
def test_sqrt_rounding():
|
||||
for i in [2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15]:
|
||||
i = from_int(i)
|
||||
for dps in [7, 15, 83, 106, 2000]:
|
||||
mp.dps = dps
|
||||
a = mpf_pow_int(mpf_sqrt(i, mp.prec, round_down), 2, mp.prec, round_down)
|
||||
b = mpf_pow_int(mpf_sqrt(i, mp.prec, round_up), 2, mp.prec, round_up)
|
||||
assert mpf_lt(a, i)
|
||||
assert mpf_gt(b, i)
|
||||
random.seed(1234)
|
||||
prec = 100
|
||||
for rnd in [round_down, round_nearest, round_ceiling]:
|
||||
for i in range(100):
|
||||
a = mpf_rand(prec)
|
||||
b = mpf_mul(a, a)
|
||||
assert mpf_sqrt(b, prec, rnd) == a
|
||||
# Test some extreme cases
|
||||
mp.dps = 100
|
||||
a = mpf(9) + 1e-90
|
||||
b = mpf(9) - 1e-90
|
||||
mp.dps = 15
|
||||
assert sqrt(a, rounding='d') == 3
|
||||
assert sqrt(a, rounding='n') == 3
|
||||
assert sqrt(a, rounding='u') > 3
|
||||
assert sqrt(b, rounding='d') < 3
|
||||
assert sqrt(b, rounding='n') == 3
|
||||
assert sqrt(b, rounding='u') == 3
|
||||
# A worst case, from the MPFR test suite
|
||||
assert sqrt(mpf('7.0503726185518891')) == mpf('2.655253776675949')
|
||||
|
||||
def test_float_sqrt():
|
||||
mp.dps = 15
|
||||
# These should round identically
|
||||
for x in [0, 1e-7, 0.1, 0.5, 1, 2, 3, 4, 5, 0.333, 76.19]:
|
||||
assert sqrt(mpf(x)) == float(x)**0.5
|
||||
assert sqrt(-1) == 1j
|
||||
assert sqrt(-2).ae(cmath.sqrt(-2))
|
||||
assert sqrt(-3).ae(cmath.sqrt(-3))
|
||||
assert sqrt(-100).ae(cmath.sqrt(-100))
|
||||
assert sqrt(1j).ae(cmath.sqrt(1j))
|
||||
assert sqrt(-1j).ae(cmath.sqrt(-1j))
|
||||
assert sqrt(math.pi + math.e*1j).ae(cmath.sqrt(math.pi + math.e*1j))
|
||||
assert sqrt(math.pi - math.e*1j).ae(cmath.sqrt(math.pi - math.e*1j))
|
||||
|
||||
def test_hypot():
|
||||
assert hypot(0, 0) == 0
|
||||
assert hypot(0, 0.33) == mpf(0.33)
|
||||
assert hypot(0.33, 0) == mpf(0.33)
|
||||
assert hypot(-0.33, 0) == mpf(0.33)
|
||||
assert hypot(3, 4) == mpf(5)
|
||||
|
||||
def test_exact_cbrt():
|
||||
for i in range(0, 20000, 200):
|
||||
assert cbrt(mpf(i*i*i)) == i
|
||||
random.seed(1)
|
||||
for prec in [100, 300, 1000, 10000]:
|
||||
mp.dps = prec
|
||||
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
|
||||
assert cbrt(mpf(A*A*A)) == A
|
||||
mp.dps = 15
|
||||
|
||||
def test_exp():
|
||||
assert exp(0) == 1
|
||||
assert exp(10000).ae(mpf('8.8068182256629215873e4342'))
|
||||
assert exp(-10000).ae(mpf('1.1354838653147360985e-4343'))
|
||||
a = exp(mpf((1, 8198646019315405L, -53, 53)))
|
||||
assert(a.bc == bitcount(a.man))
|
||||
mp.prec = 67
|
||||
a = exp(mpf((1, 1781864658064754565L, -60, 61)))
|
||||
assert(a.bc == bitcount(a.man))
|
||||
mp.prec = 53
|
||||
assert exp(ln2 * 10).ae(1024)
|
||||
assert exp(2+2j).ae(cmath.exp(2+2j))
|
||||
|
||||
def test_issue_33():
|
||||
mp.dps = 512
|
||||
a = exp(-1)
|
||||
b = exp(1)
|
||||
mp.dps = 15
|
||||
assert (+a).ae(0.36787944117144233)
|
||||
assert (+b).ae(2.7182818284590451)
|
||||
|
||||
def test_log():
|
||||
mp.dps = 15
|
||||
assert log(1) == 0
|
||||
for x in [0.5, 1.5, 2.0, 3.0, 100, 10**50, 1e-50]:
|
||||
assert log(x).ae(math.log(x))
|
||||
assert log(x, x) == 1
|
||||
assert log(1024, 2) == 10
|
||||
assert log(10**1234, 10) == 1234
|
||||
assert log(2+2j).ae(cmath.log(2+2j))
|
||||
# Accuracy near 1
|
||||
assert (log(0.6+0.8j).real*10**17).ae(2.2204460492503131)
|
||||
assert (log(0.6-0.8j).real*10**17).ae(2.2204460492503131)
|
||||
assert (log(0.8-0.6j).real*10**17).ae(2.2204460492503131)
|
||||
assert (log(1+1e-8j).real*10**16).ae(0.5)
|
||||
assert (log(1-1e-8j).real*10**16).ae(0.5)
|
||||
assert (log(-1+1e-8j).real*10**16).ae(0.5)
|
||||
assert (log(-1-1e-8j).real*10**16).ae(0.5)
|
||||
assert (log(1j+1e-8).real*10**16).ae(0.5)
|
||||
assert (log(1j-1e-8).real*10**16).ae(0.5)
|
||||
assert (log(-1j+1e-8).real*10**16).ae(0.5)
|
||||
assert (log(-1j-1e-8).real*10**16).ae(0.5)
|
||||
assert (log(1+1e-40j).real*10**80).ae(0.5)
|
||||
assert (log(1j+1e-40).real*10**80).ae(0.5)
|
||||
# Huge
|
||||
assert log(ldexp(1.234,10**20)).ae(log(2)*1e20)
|
||||
assert log(ldexp(1.234,10**200)).ae(log(2)*1e200)
|
||||
# Some special values
|
||||
assert log(mpc(0,0)) == mpc(-inf,0)
|
||||
assert isnan(log(mpc(nan,0)).real)
|
||||
assert isnan(log(mpc(nan,0)).imag)
|
||||
assert isnan(log(mpc(0,nan)).real)
|
||||
assert isnan(log(mpc(0,nan)).imag)
|
||||
assert isnan(log(mpc(nan,1)).real)
|
||||
assert isnan(log(mpc(nan,1)).imag)
|
||||
assert isnan(log(mpc(1,nan)).real)
|
||||
assert isnan(log(mpc(1,nan)).imag)
|
||||
|
||||
def test_trig_hyperb_basic():
|
||||
for x in (range(100) + range(-100,0)):
|
||||
t = x / 4.1
|
||||
assert cos(mpf(t)).ae(math.cos(t))
|
||||
assert sin(mpf(t)).ae(math.sin(t))
|
||||
assert tan(mpf(t)).ae(math.tan(t))
|
||||
assert cosh(mpf(t)).ae(math.cosh(t))
|
||||
assert sinh(mpf(t)).ae(math.sinh(t))
|
||||
assert tanh(mpf(t)).ae(math.tanh(t))
|
||||
assert sin(1+1j).ae(cmath.sin(1+1j))
|
||||
assert sin(-4-3.6j).ae(cmath.sin(-4-3.6j))
|
||||
assert cos(1+1j).ae(cmath.cos(1+1j))
|
||||
assert cos(-4-3.6j).ae(cmath.cos(-4-3.6j))
|
||||
|
||||
def test_degrees():
|
||||
assert cos(0*degree) == 1
|
||||
assert cos(90*degree).ae(0)
|
||||
assert cos(180*degree).ae(-1)
|
||||
assert cos(270*degree).ae(0)
|
||||
assert cos(360*degree).ae(1)
|
||||
assert sin(0*degree) == 0
|
||||
assert sin(90*degree).ae(1)
|
||||
assert sin(180*degree).ae(0)
|
||||
assert sin(270*degree).ae(-1)
|
||||
assert sin(360*degree).ae(0)
|
||||
|
||||
def random_complexes(N):
|
||||
random.seed(1)
|
||||
a = []
|
||||
for i in range(N):
|
||||
x1 = random.uniform(-10, 10)
|
||||
y1 = random.uniform(-10, 10)
|
||||
x2 = random.uniform(-10, 10)
|
||||
y2 = random.uniform(-10, 10)
|
||||
z1 = complex(x1, y1)
|
||||
z2 = complex(x2, y2)
|
||||
a.append((z1, z2))
|
||||
return a
|
||||
|
||||
def test_complex_powers():
|
||||
for dps in [15, 30, 100]:
|
||||
# Check accuracy for complex square root
|
||||
mp.dps = dps
|
||||
a = mpc(1j)**0.5
|
||||
assert a.real == a.imag == mpf(2)**0.5 / 2
|
||||
mp.dps = 15
|
||||
random.seed(1)
|
||||
for (z1, z2) in random_complexes(100):
|
||||
assert (mpc(z1)**mpc(z2)).ae(z1**z2, 1e-12)
|
||||
assert (e**(-pi*1j)).ae(-1)
|
||||
mp.dps = 50
|
||||
assert (e**(-pi*1j)).ae(-1)
|
||||
mp.dps = 15
|
||||
|
||||
def test_complex_sqrt_accuracy():
|
||||
def test_mpc_sqrt(lst):
|
||||
for a, b in lst:
|
||||
z = mpc(a + j*b)
|
||||
assert mpc_ae(sqrt(z*z), z)
|
||||
z = mpc(-a + j*b)
|
||||
assert mpc_ae(sqrt(z*z), -z)
|
||||
z = mpc(a - j*b)
|
||||
assert mpc_ae(sqrt(z*z), z)
|
||||
z = mpc(-a - j*b)
|
||||
assert mpc_ae(sqrt(z*z), -z)
|
||||
random.seed(2)
|
||||
N = 10
|
||||
mp.dps = 30
|
||||
dps = mp.dps
|
||||
test_mpc_sqrt([(random.uniform(0, 10),random.uniform(0, 10)) for i in range(N)])
|
||||
test_mpc_sqrt([(i + 0.1, (i + 0.2)*10**i) for i in range(N)])
|
||||
mp.dps = 15
|
||||
|
||||
def test_atan():
|
||||
mp.dps = 15
|
||||
assert atan(-2.3).ae(math.atan(-2.3))
|
||||
assert atan(1e-50) == 1e-50
|
||||
assert atan(1e50).ae(pi/2)
|
||||
assert atan(-1e-50) == -1e-50
|
||||
assert atan(-1e50).ae(-pi/2)
|
||||
assert atan(10**1000).ae(pi/2)
|
||||
for dps in [25, 70, 100, 300, 1000]:
|
||||
mp.dps = dps
|
||||
assert (4*atan(1)).ae(pi)
|
||||
mp.dps = 15
|
||||
pi2 = pi/2
|
||||
assert atan(mpc(inf,-1)).ae(pi2)
|
||||
assert atan(mpc(inf,0)).ae(pi2)
|
||||
assert atan(mpc(inf,1)).ae(pi2)
|
||||
assert atan(mpc(1,inf)).ae(pi2)
|
||||
assert atan(mpc(0,inf)).ae(pi2)
|
||||
assert atan(mpc(-1,inf)).ae(-pi2)
|
||||
assert atan(mpc(-inf,1)).ae(-pi2)
|
||||
assert atan(mpc(-inf,0)).ae(-pi2)
|
||||
assert atan(mpc(-inf,-1)).ae(-pi2)
|
||||
assert atan(mpc(-1,-inf)).ae(-pi2)
|
||||
assert atan(mpc(0,-inf)).ae(-pi2)
|
||||
assert atan(mpc(1,-inf)).ae(pi2)
|
||||
|
||||
def test_atan2():
|
||||
mp.dps = 15
|
||||
assert atan2(1,1).ae(pi/4)
|
||||
assert atan2(1,-1).ae(3*pi/4)
|
||||
assert atan2(-1,-1).ae(-3*pi/4)
|
||||
assert atan2(-1,1).ae(-pi/4)
|
||||
assert atan2(-1,0).ae(-pi/2)
|
||||
assert atan2(1,0).ae(pi/2)
|
||||
assert atan2(0,0) == 0
|
||||
assert atan2(inf,0).ae(pi/2)
|
||||
assert atan2(-inf,0).ae(-pi/2)
|
||||
assert isnan(atan2(inf,inf))
|
||||
assert isnan(atan2(-inf,inf))
|
||||
assert isnan(atan2(inf,-inf))
|
||||
assert isnan(atan2(3,nan))
|
||||
assert isnan(atan2(nan,3))
|
||||
assert isnan(atan2(0,nan))
|
||||
assert isnan(atan2(nan,0))
|
||||
assert atan2(0,inf) == 0
|
||||
assert atan2(0,-inf).ae(pi)
|
||||
assert atan2(10,inf) == 0
|
||||
assert atan2(-10,inf) == 0
|
||||
assert atan2(-10,-inf).ae(-pi)
|
||||
assert atan2(10,-inf).ae(pi)
|
||||
assert atan2(inf,10).ae(pi/2)
|
||||
assert atan2(inf,-10).ae(pi/2)
|
||||
assert atan2(-inf,10).ae(-pi/2)
|
||||
assert atan2(-inf,-10).ae(-pi/2)
|
||||
|
||||
def test_areal_inverses():
|
||||
assert asin(mpf(0)) == 0
|
||||
assert asinh(mpf(0)) == 0
|
||||
assert acosh(mpf(1)) == 0
|
||||
assert isinstance(asin(mpf(0.5)), mpf)
|
||||
assert isinstance(asin(mpf(2.0)), mpc)
|
||||
assert isinstance(acos(mpf(0.5)), mpf)
|
||||
assert isinstance(acos(mpf(2.0)), mpc)
|
||||
assert isinstance(atanh(mpf(0.1)), mpf)
|
||||
assert isinstance(atanh(mpf(1.1)), mpc)
|
||||
|
||||
random.seed(1)
|
||||
for i in range(50):
|
||||
x = random.uniform(0, 1)
|
||||
assert asin(mpf(x)).ae(math.asin(x))
|
||||
assert acos(mpf(x)).ae(math.acos(x))
|
||||
|
||||
x = random.uniform(-10, 10)
|
||||
assert asinh(mpf(x)).ae(cmath.asinh(x).real)
|
||||
assert isinstance(asinh(mpf(x)), mpf)
|
||||
x = random.uniform(1, 10)
|
||||
assert acosh(mpf(x)).ae(cmath.acosh(x).real)
|
||||
assert isinstance(acosh(mpf(x)), mpf)
|
||||
x = random.uniform(-10, 0.999)
|
||||
assert isinstance(acosh(mpf(x)), mpc)
|
||||
|
||||
x = random.uniform(-1, 1)
|
||||
assert atanh(mpf(x)).ae(cmath.atanh(x).real)
|
||||
assert isinstance(atanh(mpf(x)), mpf)
|
||||
|
||||
dps = mp.dps
|
||||
mp.dps = 300
|
||||
assert isinstance(asin(0.5), mpf)
|
||||
mp.dps = 1000
|
||||
assert asin(1).ae(pi/2)
|
||||
assert asin(-1).ae(-pi/2)
|
||||
mp.dps = dps
|
||||
|
||||
def test_invhyperb_inaccuracy():
|
||||
mp.dps = 15
|
||||
assert (asinh(1e-5)*10**5).ae(0.99999999998333333)
|
||||
assert (asinh(1e-10)*10**10).ae(1)
|
||||
assert (asinh(1e-50)*10**50).ae(1)
|
||||
assert (asinh(-1e-5)*10**5).ae(-0.99999999998333333)
|
||||
assert (asinh(-1e-10)*10**10).ae(-1)
|
||||
assert (asinh(-1e-50)*10**50).ae(-1)
|
||||
assert asinh(10**20).ae(46.744849040440862)
|
||||
assert asinh(-10**20).ae(-46.744849040440862)
|
||||
assert (tanh(1e-10)*10**10).ae(1)
|
||||
assert (tanh(-1e-10)*10**10).ae(-1)
|
||||
assert (atanh(1e-10)*10**10).ae(1)
|
||||
assert (atanh(-1e-10)*10**10).ae(-1)
|
||||
|
||||
def test_complex_functions():
|
||||
for x in (range(10) + range(-10,0)):
|
||||
for y in (range(10) + range(-10,0)):
|
||||
z = complex(x, y)/4.3 + 0.01j
|
||||
assert exp(mpc(z)).ae(cmath.exp(z))
|
||||
assert log(mpc(z)).ae(cmath.log(z))
|
||||
assert cos(mpc(z)).ae(cmath.cos(z))
|
||||
assert sin(mpc(z)).ae(cmath.sin(z))
|
||||
assert tan(mpc(z)).ae(cmath.tan(z))
|
||||
assert sinh(mpc(z)).ae(cmath.sinh(z))
|
||||
assert cosh(mpc(z)).ae(cmath.cosh(z))
|
||||
assert tanh(mpc(z)).ae(cmath.tanh(z))
|
||||
|
||||
def test_complex_inverse_functions():
|
||||
for (z1, z2) in random_complexes(30):
|
||||
# apparently cmath uses a different branch, so we
|
||||
# can't use it for comparison
|
||||
assert sinh(asinh(z1)).ae(z1)
|
||||
#
|
||||
assert acosh(z1).ae(cmath.acosh(z1))
|
||||
assert atanh(z1).ae(cmath.atanh(z1))
|
||||
assert atan(z1).ae(cmath.atan(z1))
|
||||
# the reason we set a big eps here is that the cmath
|
||||
# functions are inaccurate
|
||||
assert asin(z1).ae(cmath.asin(z1), rel_eps=1e-12)
|
||||
assert acos(z1).ae(cmath.acos(z1), rel_eps=1e-12)
|
||||
one = mpf(1)
|
||||
for i in range(-9, 10, 3):
|
||||
for k in range(-9, 10, 3):
|
||||
a = 0.9*j*10**k + 0.8*one*10**i
|
||||
b = cos(acos(a))
|
||||
assert b.ae(a)
|
||||
b = sin(asin(a))
|
||||
assert b.ae(a)
|
||||
one = mpf(1)
|
||||
err = 2*10**-15
|
||||
for i in range(-9, 9, 3):
|
||||
for k in range(-9, 9, 3):
|
||||
a = -0.9*10**k + j*0.8*one*10**i
|
||||
b = cosh(acosh(a))
|
||||
assert b.ae(a, err)
|
||||
b = sinh(asinh(a))
|
||||
assert b.ae(a, err)
|
||||
|
||||
def test_reciprocal_functions():
|
||||
assert sec(3).ae(-1.01010866590799375)
|
||||
assert csc(3).ae(7.08616739573718592)
|
||||
assert cot(3).ae(-7.01525255143453347)
|
||||
assert sech(3).ae(0.0993279274194332078)
|
||||
assert csch(3).ae(0.0998215696688227329)
|
||||
assert coth(3).ae(1.00496982331368917)
|
||||
assert asec(3).ae(1.23095941734077468)
|
||||
assert acsc(3).ae(0.339836909454121937)
|
||||
assert acot(3).ae(0.321750554396642193)
|
||||
assert asech(0.5).ae(1.31695789692481671)
|
||||
assert acsch(3).ae(0.327450150237258443)
|
||||
assert acoth(3).ae(0.346573590279972655)
|
||||
|
||||
def test_ldexp():
|
||||
mp.dps = 15
|
||||
assert ldexp(mpf(2.5), 0) == 2.5
|
||||
assert ldexp(mpf(2.5), -1) == 1.25
|
||||
assert ldexp(mpf(2.5), 2) == 10
|
||||
assert ldexp(mpf('inf'), 3) == mpf('inf')
|
||||
|
||||
def test_frexp():
|
||||
mp.dps = 15
|
||||
assert frexp(0) == (0.0, 0)
|
||||
assert frexp(9) == (0.5625, 4)
|
||||
assert frexp(1) == (0.5, 1)
|
||||
assert frexp(0.2) == (0.8, -2)
|
||||
assert frexp(1000) == (0.9765625, 10)
|
||||
|
||||
def test_aliases():
|
||||
assert ln(7) == log(7)
|
||||
assert log10(3.75) == log(3.75,10)
|
||||
assert degrees(5.6) == 5.6 / degree
|
||||
assert radians(5.6) == 5.6 * degree
|
||||
assert power(-1,0.5) == j
|
||||
assert modf(25,7) == 4.0 and isinstance(modf(25,7), mpf)
|
||||
|
||||
def test_arg_sign():
|
||||
assert arg(3) == 0
|
||||
assert arg(-3).ae(pi)
|
||||
assert arg(j).ae(pi/2)
|
||||
assert arg(-j).ae(-pi/2)
|
||||
assert arg(0) == 0
|
||||
assert isnan(atan2(3,nan))
|
||||
assert isnan(atan2(nan,3))
|
||||
assert isnan(atan2(0,nan))
|
||||
assert isnan(atan2(nan,0))
|
||||
assert isnan(atan2(nan,nan))
|
||||
assert arg(inf) == 0
|
||||
assert arg(-inf).ae(pi)
|
||||
assert isnan(arg(nan))
|
||||
#assert arg(inf*j).ae(pi/2)
|
||||
assert sign(0) == 0
|
||||
assert sign(3) == 1
|
||||
assert sign(-3) == -1
|
||||
assert sign(inf) == 1
|
||||
assert sign(-inf) == -1
|
||||
assert isnan(sign(nan))
|
||||
assert sign(j) == j
|
||||
assert sign(-3*j) == -j
|
||||
assert sign(1+j).ae((1+j)/sqrt(2))
|
||||
|
||||
def test_misc_bugs():
|
||||
# test that this doesn't raise an exception
|
||||
mp.dps = 1000
|
||||
log(1302)
|
||||
mp.dps = 15
|
||||
|
||||
def test_arange():
|
||||
assert arange(10) == [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0'),
|
||||
mpf('4.0'), mpf('5.0'), mpf('6.0'), mpf('7.0'),
|
||||
mpf('8.0'), mpf('9.0')]
|
||||
assert arange(-5, 5) == [mpf('-5.0'), mpf('-4.0'), mpf('-3.0'),
|
||||
mpf('-2.0'), mpf('-1.0'), mpf('0.0'),
|
||||
mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
||||
assert arange(0, 1, 0.1) == [mpf('0.0'), mpf('0.10000000000000001'),
|
||||
mpf('0.20000000000000001'),
|
||||
mpf('0.30000000000000004'),
|
||||
mpf('0.40000000000000002'),
|
||||
mpf('0.5'), mpf('0.60000000000000009'),
|
||||
mpf('0.70000000000000007'),
|
||||
mpf('0.80000000000000004'),
|
||||
mpf('0.90000000000000002')]
|
||||
assert arange(17, -9, -3) == [mpf('17.0'), mpf('14.0'), mpf('11.0'),
|
||||
mpf('8.0'), mpf('5.0'), mpf('2.0'),
|
||||
mpf('-1.0'), mpf('-4.0'), mpf('-7.0')]
|
||||
assert arange(0.2, 0.1, -0.1) == [mpf('0.20000000000000001')]
|
||||
assert arange(0) == []
|
||||
assert arange(1000, -1) == []
|
||||
assert arange(-1.23, 3.21, -0.0000001) == []
|
||||
|
||||
def test_linspace():
|
||||
assert linspace(2, 9, 7) == [mpf('2.0'), mpf('3.166666666666667'),
|
||||
mpf('4.3333333333333339'), mpf('5.5'), mpf('6.666666666666667'),
|
||||
mpf('7.8333333333333339'), mpf('9.0')] == linspace(mpi(2, 9), 7)
|
||||
assert linspace(2, 9, 7, endpoint=0) == [mpf('2.0'), mpf('3.0'), mpf('4.0'),
|
||||
mpf('5.0'), mpf('6.0'), mpf('7.0'), mpf('8.0')]
|
||||
assert linspace(2, 7, 1) == [mpf(2)]
|
||||
|
||||
def test_float_cbrt():
|
||||
mp.dps = 30
|
||||
for a in arange(0,10,0.1):
|
||||
assert cbrt(a*a*a).ae(a, eps)
|
||||
assert cbrt(-1).ae(0.5 + j*sqrt(3)/2)
|
||||
one_third = mpf(1)/3
|
||||
for a in arange(0,10,2.7) + [0.1 + 10**5]:
|
||||
a = mpc(a + 1.1j)
|
||||
r1 = cbrt(a)
|
||||
mp.dps += 10
|
||||
r2 = pow(a, one_third)
|
||||
mp.dps -= 10
|
||||
assert r1.ae(r2, eps)
|
||||
mp.dps = 100
|
||||
for n in range(100, 301, 100):
|
||||
w = 10**n + j*10**-3
|
||||
z = w*w*w
|
||||
r = cbrt(z)
|
||||
assert mpc_ae(r, w, eps)
|
||||
mp.dps = 15
|
||||
|
||||
def test_root():
|
||||
mp.dps = 30
|
||||
random.seed(1)
|
||||
a = random.randint(0, 10000)
|
||||
p = a*a*a
|
||||
r = nthroot(mpf(p), 3)
|
||||
assert r == a
|
||||
for n in range(4, 10):
|
||||
p = p*a
|
||||
assert nthroot(mpf(p), n) == a
|
||||
mp.dps = 40
|
||||
for n in range(10, 5000, 100):
|
||||
for a in [random.random()*10000, random.random()*10**100]:
|
||||
r = nthroot(a, n)
|
||||
r1 = pow(a, mpf(1)/n)
|
||||
assert r.ae(r1)
|
||||
r = nthroot(a, -n)
|
||||
r1 = pow(a, -mpf(1)/n)
|
||||
assert r.ae(r1)
|
||||
# XXX: this is broken right now
|
||||
# tests for nthroot rounding
|
||||
for rnd in ['nearest', 'up', 'down']:
|
||||
mp.rounding = rnd
|
||||
for n in [-5, -3, 3, 5]:
|
||||
prec = 50
|
||||
for i in xrange(10):
|
||||
mp.prec = prec
|
||||
a = rand()
|
||||
mp.prec = 2*prec
|
||||
b = a**n
|
||||
mp.prec = prec
|
||||
r = nthroot(b, n)
|
||||
assert r == a
|
||||
mp.dps = 30
|
||||
for n in range(3, 21):
|
||||
a = (random.random() + j*random.random())
|
||||
assert nthroot(a, n).ae(pow(a, mpf(1)/n))
|
||||
assert mpc_ae(nthroot(a, n), pow(a, mpf(1)/n))
|
||||
a = (random.random()*10**100 + j*random.random())
|
||||
r = nthroot(a, n)
|
||||
mp.dps += 4
|
||||
r1 = pow(a, mpf(1)/n)
|
||||
mp.dps -= 4
|
||||
assert r.ae(r1)
|
||||
assert mpc_ae(r, r1, eps)
|
||||
r = nthroot(a, -n)
|
||||
mp.dps += 4
|
||||
r1 = pow(a, -mpf(1)/n)
|
||||
mp.dps -= 4
|
||||
assert r.ae(r1)
|
||||
assert mpc_ae(r, r1, eps)
|
||||
mp.dps = 15
|
||||
assert nthroot(4, 1) == 4
|
||||
assert nthroot(4, 0) == 1
|
||||
assert nthroot(4, -1) == 0.25
|
||||
assert nthroot(inf, 1) == inf
|
||||
assert nthroot(inf, 2) == inf
|
||||
assert nthroot(inf, 3) == inf
|
||||
assert nthroot(inf, -1) == 0
|
||||
assert nthroot(inf, -2) == 0
|
||||
assert nthroot(inf, -3) == 0
|
||||
assert nthroot(j, 1) == j
|
||||
assert nthroot(j, 0) == 1
|
||||
assert nthroot(j, -1) == -j
|
||||
assert isnan(nthroot(nan, 1))
|
||||
assert isnan(nthroot(nan, 0))
|
||||
assert isnan(nthroot(nan, -1))
|
||||
assert isnan(nthroot(inf, 0))
|
||||
assert root(2,3) == nthroot(2,3)
|
||||
assert root(16,4,0) == 2
|
||||
assert root(16,4,1) == 2j
|
||||
assert root(16,4,2) == -2
|
||||
assert root(16,4,3) == -2j
|
||||
assert root(16,4,4) == 2
|
||||
assert root(-125,3,1) == -5
|
||||
|
||||
def test_issue_96():
|
||||
for dps in [20, 80]:
|
||||
mp.dps = dps
|
||||
r = nthroot(mpf('-1e-20'), 4)
|
||||
assert r.ae(mpf(10)**(-5) * (1 + j) * mpf(2)**(-0.5))
|
||||
mp.dps = 80
|
||||
assert nthroot('-1e-3', 4).ae(mpf(10)**(-3./4) * (1 + j)/sqrt(2))
|
||||
assert nthroot('-1e-6', 4).ae((1 + j)/(10 * sqrt(20)))
|
||||
# Check that this doesn't take eternity to compute
|
||||
mp.dps = 20
|
||||
assert nthroot('-1e100000000', 4).ae((1+j)*mpf('1e25000000')/sqrt(2))
|
||||
mp.dps = 15
|
||||
|
||||
def test_perturbation_rounding():
|
||||
mp.dps = 100
|
||||
a = pi/10**50
|
||||
b = -pi/10**50
|
||||
c = 1 + a
|
||||
d = 1 + b
|
||||
mp.dps = 15
|
||||
assert exp(a) == 1
|
||||
assert exp(a, rounding='c') > 1
|
||||
assert exp(b, rounding='c') == 1
|
||||
assert exp(a, rounding='f') == 1
|
||||
assert exp(b, rounding='f') < 1
|
||||
assert cos(a) == 1
|
||||
assert cos(a, rounding='c') == 1
|
||||
assert cos(b, rounding='c') == 1
|
||||
assert cos(a, rounding='f') < 1
|
||||
assert cos(b, rounding='f') < 1
|
||||
for f in [sin, atan, asinh, tanh]:
|
||||
assert f(a) == +a
|
||||
assert f(a, rounding='c') > a
|
||||
assert f(a, rounding='f') < a
|
||||
assert f(b) == +b
|
||||
assert f(b, rounding='c') > b
|
||||
assert f(b, rounding='f') < b
|
||||
for f in [asin, tan, sinh, atanh]:
|
||||
assert f(a) == +a
|
||||
assert f(b) == +b
|
||||
assert f(a, rounding='c') > a
|
||||
assert f(b, rounding='c') > b
|
||||
assert f(a, rounding='f') < a
|
||||
assert f(b, rounding='f') < b
|
||||
assert ln(c) == +a
|
||||
assert ln(d) == +b
|
||||
assert ln(c, rounding='c') > a
|
||||
assert ln(c, rounding='f') < a
|
||||
assert ln(d, rounding='c') > b
|
||||
assert ln(d, rounding='f') < b
|
||||
assert cosh(a) == 1
|
||||
assert cosh(b) == 1
|
||||
assert cosh(a, rounding='c') > 1
|
||||
assert cosh(b, rounding='c') > 1
|
||||
assert cosh(a, rounding='f') == 1
|
||||
assert cosh(b, rounding='f') == 1
|
||||
|
||||
def test_integer_parts():
|
||||
assert floor(3.2) == 3
|
||||
assert ceil(3.2) == 4
|
||||
assert floor(3.2+5j) == 3+5j
|
||||
assert ceil(3.2+5j) == 4+5j
|
||||
|
||||
def test_complex_parts():
|
||||
assert fabs('3') == 3
|
||||
assert fabs(3+4j) == 5
|
||||
assert re(3) == 3
|
||||
assert re(1+4j) == 1
|
||||
assert im(3) == 0
|
||||
assert im(1+4j) == 4
|
||||
assert conj(3) == 3
|
||||
assert conj(3+4j) == 3-4j
|
||||
assert mpf(3).conjugate() == 3
|
||||
|
||||
def test_cospi_sinpi():
|
||||
assert sinpi(0) == 0
|
||||
assert sinpi(0.5) == 1
|
||||
assert sinpi(1) == 0
|
||||
assert sinpi(1.5) == -1
|
||||
assert sinpi(2) == 0
|
||||
assert sinpi(2.5) == 1
|
||||
assert sinpi(-0.5) == -1
|
||||
assert cospi(0) == 1
|
||||
assert cospi(0.5) == 0
|
||||
assert cospi(1) == -1
|
||||
assert cospi(1.5) == 0
|
||||
assert cospi(2) == 1
|
||||
assert cospi(2.5) == 0
|
||||
assert cospi(-0.5) == 0
|
||||
assert cospi(100000000000.25).ae(sqrt(2)/2)
|
||||
a = cospi(2+3j)
|
||||
assert a.real.ae(cos((2+3j)*pi).real)
|
||||
assert a.imag == 0
|
||||
b = sinpi(2+3j)
|
||||
assert b.imag.ae(sin((2+3j)*pi).imag)
|
||||
assert b.real == 0
|
||||
mp.dps = 35
|
||||
x1 = mpf(10000) - mpf('1e-15')
|
||||
x2 = mpf(10000) + mpf('1e-15')
|
||||
x3 = mpf(10000.5) - mpf('1e-15')
|
||||
x4 = mpf(10000.5) + mpf('1e-15')
|
||||
x5 = mpf(10001) - mpf('1e-15')
|
||||
x6 = mpf(10001) + mpf('1e-15')
|
||||
x7 = mpf(10001.5) - mpf('1e-15')
|
||||
x8 = mpf(10001.5) + mpf('1e-15')
|
||||
mp.dps = 15
|
||||
M = 10**15
|
||||
assert (sinpi(x1)*M).ae(-pi)
|
||||
assert (sinpi(x2)*M).ae(pi)
|
||||
assert (cospi(x3)*M).ae(pi)
|
||||
assert (cospi(x4)*M).ae(-pi)
|
||||
assert (sinpi(x5)*M).ae(pi)
|
||||
assert (sinpi(x6)*M).ae(-pi)
|
||||
assert (cospi(x7)*M).ae(-pi)
|
||||
assert (cospi(x8)*M).ae(pi)
|
||||
assert 0.999 < cospi(x1, rounding='d') < 1
|
||||
assert 0.999 < cospi(x2, rounding='d') < 1
|
||||
assert 0.999 < sinpi(x3, rounding='d') < 1
|
||||
assert 0.999 < sinpi(x4, rounding='d') < 1
|
||||
assert -1 < cospi(x5, rounding='d') < -0.999
|
||||
assert -1 < cospi(x6, rounding='d') < -0.999
|
||||
assert -1 < sinpi(x7, rounding='d') < -0.999
|
||||
assert -1 < sinpi(x8, rounding='d') < -0.999
|
||||
assert (sinpi(1e-15)*M).ae(pi)
|
||||
assert (sinpi(-1e-15)*M).ae(-pi)
|
||||
assert cospi(1e-15) == 1
|
||||
assert cospi(1e-15, rounding='d') < 1
|
||||
|
||||
def test_expj():
|
||||
assert expj(0) == 1
|
||||
assert expj(1).ae(exp(j))
|
||||
assert expj(j).ae(exp(-1))
|
||||
assert expj(1+j).ae(exp(j*(1+j)))
|
||||
assert expjpi(0) == 1
|
||||
assert expjpi(1).ae(exp(j*pi))
|
||||
assert expjpi(j).ae(exp(-pi))
|
||||
assert expjpi(1+j).ae(exp(j*pi*(1+j)))
|
||||
assert expjpi(-10**15 * j).ae('2.22579818340535731e+1364376353841841')
|
||||
|
||||
def test_sinc():
|
||||
assert sinc(0) == sincpi(0) == 1
|
||||
assert sinc(inf) == sincpi(inf) == 0
|
||||
assert sinc(-inf) == sincpi(-inf) == 0
|
||||
assert sinc(2).ae(0.45464871341284084770)
|
||||
assert sinc(2+3j).ae(0.4463290318402435457-2.7539470277436474940j)
|
||||
assert sincpi(2) == 0
|
||||
assert sincpi(1.5).ae(-0.212206590789193781)
|
||||
|
||||
def test_fibonacci():
|
||||
mp.dps = 15
|
||||
assert [fibonacci(n) for n in range(-5, 10)] == \
|
||||
[5, -3, 2, -1, 1, 0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
|
||||
assert fib(2.5).ae(1.4893065462657091)
|
||||
assert fib(3+4j).ae(-5248.51130728372 - 14195.962288353j)
|
||||
assert fib(1000).ae(4.3466557686937455e+208)
|
||||
assert str(fib(10**100)) == '6.24499112864607e+2089876402499787337692720892375554168224592399182109535392875613974104853496745963277658556235103534'
|
||||
mp.dps = 2100
|
||||
a = fib(10000)
|
||||
assert a % 10**10 == 9947366875
|
||||
mp.dps = 15
|
||||
assert fibonacci(inf) == inf
|
||||
assert fib(3+0j) == 2
|
||||
|
||||
def test_call_with_dps():
|
||||
mp.dps = 15
|
||||
assert abs(exp(1, dps=30)-e(dps=35)) < 1e-29
|
||||
|
||||
def test_tanh():
|
||||
mp.dps = 15
|
||||
assert tanh(0) == 0
|
||||
assert tanh(inf) == 1
|
||||
assert tanh(-inf) == -1
|
||||
assert isnan(tanh(nan))
|
||||
assert tanh(mpc('inf', '0')) == 1
|
||||
|
||||
def test_atanh():
|
||||
mp.dps = 15
|
||||
assert atanh(0) == 0
|
||||
assert atanh(0.5).ae(0.54930614433405484570)
|
||||
assert atanh(-0.5).ae(-0.54930614433405484570)
|
||||
assert atanh(1) == inf
|
||||
assert atanh(-1) == -inf
|
||||
assert isnan(atanh(nan))
|
||||
assert isinstance(atanh(1), mpf)
|
||||
assert isinstance(atanh(-1), mpf)
|
||||
# Limits at infinity
|
||||
jpi2 = j*pi/2
|
||||
assert atanh(inf).ae(-jpi2)
|
||||
assert atanh(-inf).ae(jpi2)
|
||||
assert atanh(mpc(inf,-1)).ae(-jpi2)
|
||||
assert atanh(mpc(inf,0)).ae(-jpi2)
|
||||
assert atanh(mpc(inf,1)).ae(jpi2)
|
||||
assert atanh(mpc(1,inf)).ae(jpi2)
|
||||
assert atanh(mpc(0,inf)).ae(jpi2)
|
||||
assert atanh(mpc(-1,inf)).ae(jpi2)
|
||||
assert atanh(mpc(-inf,1)).ae(jpi2)
|
||||
assert atanh(mpc(-inf,0)).ae(jpi2)
|
||||
assert atanh(mpc(-inf,-1)).ae(-jpi2)
|
||||
assert atanh(mpc(-1,-inf)).ae(-jpi2)
|
||||
assert atanh(mpc(0,-inf)).ae(-jpi2)
|
||||
assert atanh(mpc(1,-inf)).ae(-jpi2)
|
||||
|
||||
def test_expm1():
|
||||
mp.dps = 15
|
||||
assert expm1(0) == 0
|
||||
assert expm1(3).ae(exp(3)-1)
|
||||
assert expm1(inf) == inf
|
||||
assert expm1(1e-10)*1e10
|
||||
assert expm1(1e-50).ae(1e-50)
|
||||
assert (expm1(1e-10)*1e10).ae(1.00000000005)
|
||||
|
||||
def test_powm1():
|
||||
mp.dps = 15
|
||||
assert powm1(2,3) == 7
|
||||
assert powm1(-1,2) == 0
|
||||
assert powm1(-1,0) == 0
|
||||
assert powm1(-2,0) == 0
|
||||
assert powm1(3+4j,0) == 0
|
||||
assert powm1(0,1) == -1
|
||||
assert powm1(0,0) == 0
|
||||
assert powm1(1,0) == 0
|
||||
assert powm1(1,2) == 0
|
||||
assert powm1(1,3+4j) == 0
|
||||
assert powm1(1,5) == 0
|
||||
assert powm1(j,4) == 0
|
||||
assert powm1(-j,4) == 0
|
||||
assert (powm1(2,1e-100)*1e100).ae(ln2)
|
||||
assert powm1(2,'1e-100000000000') != 0
|
||||
assert (powm1(fadd(1,1e-100,exact=True), 5)*1e100).ae(5)
|
||||
|
||||
def test_unitroots():
|
||||
assert unitroots(1) == [1]
|
||||
assert unitroots(2) == [1, -1]
|
||||
a, b, c = unitroots(3)
|
||||
assert a == 1
|
||||
assert b.ae(-0.5 + 0.86602540378443864676j)
|
||||
assert c.ae(-0.5 - 0.86602540378443864676j)
|
||||
assert unitroots(1, primitive=True) == [1]
|
||||
assert unitroots(2, primitive=True) == [-1]
|
||||
assert unitroots(3, primitive=True) == unitroots(3)[1:]
|
||||
assert unitroots(4, primitive=True) == [j, -j]
|
||||
assert len(unitroots(17, primitive=True)) == 16
|
||||
assert len(unitroots(16, primitive=True)) == 8
|
||||
|
||||
def test_cyclotomic():
|
||||
mp.dps = 15
|
||||
assert [cyclotomic(n,1) for n in range(31)] == [1,0,2,3,2,5,1,7,2,3,1,11,1,13,1,1,2,17,1,19,1,1,1,23,1,5,1,3,1,29,1]
|
||||
assert [cyclotomic(n,-1) for n in range(31)] == [1,-2,0,1,2,1,3,1,2,1,5,1,1,1,7,1,2,1,3,1,1,1,11,1,1,1,13,1,1,1,1]
|
||||
assert [cyclotomic(n,j) for n in range(21)] == [1,-1+j,1+j,j,0,1,-j,j,2,-j,1,j,3,1,-j,1,2,1,j,j,5]
|
||||
assert [cyclotomic(n,-j) for n in range(21)] == [1,-1-j,1-j,-j,0,1,j,-j,2,j,1,-j,3,1,j,1,2,1,-j,-j,5]
|
||||
assert cyclotomic(1624,j) == 1
|
||||
assert cyclotomic(33600,j) == 1
|
||||
u = sqrt(j, prec=500)
|
||||
assert cyclotomic(8, u).ae(0)
|
||||
assert cyclotomic(30, u).ae(5.8284271247461900976)
|
||||
assert cyclotomic(2040, u).ae(1)
|
||||
assert cyclotomic(0,2.5) == 1
|
||||
assert cyclotomic(1,2.5) == 2.5-1
|
||||
assert cyclotomic(2,2.5) == 2.5+1
|
||||
assert cyclotomic(3,2.5) == 2.5**2 + 2.5 + 1
|
||||
assert cyclotomic(7,2.5) == 406.234375
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,658 +0,0 @@
|
|||
from mpmath import *
|
||||
from mpmath.libmp import round_up, from_float, mpf_zeta_int
|
||||
|
||||
def test_zeta_int_bug():
|
||||
assert mpf_zeta_int(0, 10) == from_float(-0.5)
|
||||
|
||||
def test_bernoulli():
|
||||
assert bernfrac(0) == (1,1)
|
||||
assert bernfrac(1) == (-1,2)
|
||||
assert bernfrac(2) == (1,6)
|
||||
assert bernfrac(3) == (0,1)
|
||||
assert bernfrac(4) == (-1,30)
|
||||
assert bernfrac(5) == (0,1)
|
||||
assert bernfrac(6) == (1,42)
|
||||
assert bernfrac(8) == (-1,30)
|
||||
assert bernfrac(10) == (5,66)
|
||||
assert bernfrac(12) == (-691,2730)
|
||||
assert bernfrac(18) == (43867,798)
|
||||
p, q = bernfrac(228)
|
||||
assert p % 10**10 == 164918161
|
||||
assert q == 625170
|
||||
p, q = bernfrac(1000)
|
||||
assert p % 10**10 == 7950421099
|
||||
assert q == 342999030
|
||||
mp.dps = 15
|
||||
assert bernoulli(0) == 1
|
||||
assert bernoulli(1) == -0.5
|
||||
assert bernoulli(2).ae(1./6)
|
||||
assert bernoulli(3) == 0
|
||||
assert bernoulli(4).ae(-1./30)
|
||||
assert bernoulli(5) == 0
|
||||
assert bernoulli(6).ae(1./42)
|
||||
assert str(bernoulli(10)) == '0.0757575757575758'
|
||||
assert str(bernoulli(234)) == '7.62772793964344e+267'
|
||||
assert str(bernoulli(10**5)) == '-5.82229431461335e+376755'
|
||||
assert str(bernoulli(10**8+2)) == '1.19570355039953e+676752584'
|
||||
|
||||
mp.dps = 50
|
||||
assert str(bernoulli(10)) == '0.075757575757575757575757575757575757575757575757576'
|
||||
assert str(bernoulli(234)) == '7.6277279396434392486994969020496121553385863373331e+267'
|
||||
assert str(bernoulli(10**5)) == '-5.8222943146133508236497045360612887555320691004308e+376755'
|
||||
assert str(bernoulli(10**8+2)) == '1.1957035503995297272263047884604346914602088317782e+676752584'
|
||||
|
||||
mp.dps = 1000
|
||||
assert bernoulli(10).ae(mpf(5)/66)
|
||||
|
||||
mp.dps = 50000
|
||||
assert bernoulli(10).ae(mpf(5)/66)
|
||||
|
||||
mp.dps = 15
|
||||
|
||||
def test_bernpoly_eulerpoly():
|
||||
mp.dps = 15
|
||||
assert bernpoly(0,-1).ae(1)
|
||||
assert bernpoly(0,0).ae(1)
|
||||
assert bernpoly(0,'1/2').ae(1)
|
||||
assert bernpoly(0,'3/4').ae(1)
|
||||
assert bernpoly(0,1).ae(1)
|
||||
assert bernpoly(0,2).ae(1)
|
||||
assert bernpoly(1,-1).ae('-3/2')
|
||||
assert bernpoly(1,0).ae('-1/2')
|
||||
assert bernpoly(1,'1/2').ae(0)
|
||||
assert bernpoly(1,'3/4').ae('1/4')
|
||||
assert bernpoly(1,1).ae('1/2')
|
||||
assert bernpoly(1,2).ae('3/2')
|
||||
assert bernpoly(2,-1).ae('13/6')
|
||||
assert bernpoly(2,0).ae('1/6')
|
||||
assert bernpoly(2,'1/2').ae('-1/12')
|
||||
assert bernpoly(2,'3/4').ae('-1/48')
|
||||
assert bernpoly(2,1).ae('1/6')
|
||||
assert bernpoly(2,2).ae('13/6')
|
||||
assert bernpoly(3,-1).ae(-3)
|
||||
assert bernpoly(3,0).ae(0)
|
||||
assert bernpoly(3,'1/2').ae(0)
|
||||
assert bernpoly(3,'3/4').ae('-3/64')
|
||||
assert bernpoly(3,1).ae(0)
|
||||
assert bernpoly(3,2).ae(3)
|
||||
assert bernpoly(4,-1).ae('119/30')
|
||||
assert bernpoly(4,0).ae('-1/30')
|
||||
assert bernpoly(4,'1/2').ae('7/240')
|
||||
assert bernpoly(4,'3/4').ae('7/3840')
|
||||
assert bernpoly(4,1).ae('-1/30')
|
||||
assert bernpoly(4,2).ae('119/30')
|
||||
assert bernpoly(5,-1).ae(-5)
|
||||
assert bernpoly(5,0).ae(0)
|
||||
assert bernpoly(5,'1/2').ae(0)
|
||||
assert bernpoly(5,'3/4').ae('25/1024')
|
||||
assert bernpoly(5,1).ae(0)
|
||||
assert bernpoly(5,2).ae(5)
|
||||
assert bernpoly(10,-1).ae('665/66')
|
||||
assert bernpoly(10,0).ae('5/66')
|
||||
assert bernpoly(10,'1/2').ae('-2555/33792')
|
||||
assert bernpoly(10,'3/4').ae('-2555/34603008')
|
||||
assert bernpoly(10,1).ae('5/66')
|
||||
assert bernpoly(10,2).ae('665/66')
|
||||
assert bernpoly(11,-1).ae(-11)
|
||||
assert bernpoly(11,0).ae(0)
|
||||
assert bernpoly(11,'1/2').ae(0)
|
||||
assert bernpoly(11,'3/4').ae('-555731/4194304')
|
||||
assert bernpoly(11,1).ae(0)
|
||||
assert bernpoly(11,2).ae(11)
|
||||
assert eulerpoly(0,-1).ae(1)
|
||||
assert eulerpoly(0,0).ae(1)
|
||||
assert eulerpoly(0,'1/2').ae(1)
|
||||
assert eulerpoly(0,'3/4').ae(1)
|
||||
assert eulerpoly(0,1).ae(1)
|
||||
assert eulerpoly(0,2).ae(1)
|
||||
assert eulerpoly(1,-1).ae('-3/2')
|
||||
assert eulerpoly(1,0).ae('-1/2')
|
||||
assert eulerpoly(1,'1/2').ae(0)
|
||||
assert eulerpoly(1,'3/4').ae('1/4')
|
||||
assert eulerpoly(1,1).ae('1/2')
|
||||
assert eulerpoly(1,2).ae('3/2')
|
||||
assert eulerpoly(2,-1).ae(2)
|
||||
assert eulerpoly(2,0).ae(0)
|
||||
assert eulerpoly(2,'1/2').ae('-1/4')
|
||||
assert eulerpoly(2,'3/4').ae('-3/16')
|
||||
assert eulerpoly(2,1).ae(0)
|
||||
assert eulerpoly(2,2).ae(2)
|
||||
assert eulerpoly(3,-1).ae('-9/4')
|
||||
assert eulerpoly(3,0).ae('1/4')
|
||||
assert eulerpoly(3,'1/2').ae(0)
|
||||
assert eulerpoly(3,'3/4').ae('-11/64')
|
||||
assert eulerpoly(3,1).ae('-1/4')
|
||||
assert eulerpoly(3,2).ae('9/4')
|
||||
assert eulerpoly(4,-1).ae(2)
|
||||
assert eulerpoly(4,0).ae(0)
|
||||
assert eulerpoly(4,'1/2').ae('5/16')
|
||||
assert eulerpoly(4,'3/4').ae('57/256')
|
||||
assert eulerpoly(4,1).ae(0)
|
||||
assert eulerpoly(4,2).ae(2)
|
||||
assert eulerpoly(5,-1).ae('-3/2')
|
||||
assert eulerpoly(5,0).ae('-1/2')
|
||||
assert eulerpoly(5,'1/2').ae(0)
|
||||
assert eulerpoly(5,'3/4').ae('361/1024')
|
||||
assert eulerpoly(5,1).ae('1/2')
|
||||
assert eulerpoly(5,2).ae('3/2')
|
||||
assert eulerpoly(10,-1).ae(2)
|
||||
assert eulerpoly(10,0).ae(0)
|
||||
assert eulerpoly(10,'1/2').ae('-50521/1024')
|
||||
assert eulerpoly(10,'3/4').ae('-36581523/1048576')
|
||||
assert eulerpoly(10,1).ae(0)
|
||||
assert eulerpoly(10,2).ae(2)
|
||||
assert eulerpoly(11,-1).ae('-699/4')
|
||||
assert eulerpoly(11,0).ae('691/4')
|
||||
assert eulerpoly(11,'1/2').ae(0)
|
||||
assert eulerpoly(11,'3/4').ae('-512343611/4194304')
|
||||
assert eulerpoly(11,1).ae('-691/4')
|
||||
assert eulerpoly(11,2).ae('699/4')
|
||||
# Potential accuracy issues
|
||||
assert bernpoly(10000,10000).ae('5.8196915936323387117e+39999')
|
||||
assert bernpoly(200,17.5).ae(3.8048418524583064909e244)
|
||||
assert eulerpoly(200,17.5).ae(-3.7309911582655785929e275)
|
||||
|
||||
def test_gamma():
|
||||
mp.dps = 15
|
||||
assert gamma(0.25).ae(3.6256099082219083119)
|
||||
assert gamma(0.0001).ae(9999.4228832316241908)
|
||||
assert gamma(300).ae('1.0201917073881354535e612')
|
||||
assert gamma(-0.5).ae(-3.5449077018110320546)
|
||||
assert gamma(-7.43).ae(0.00026524416464197007186)
|
||||
#assert gamma(Rational(1,2)) == gamma(0.5)
|
||||
#assert gamma(Rational(-7,3)).ae(gamma(mpf(-7)/3))
|
||||
assert gamma(1+1j).ae(0.49801566811835604271 - 0.15494982830181068512j)
|
||||
assert gamma(-1+0.01j).ae(-0.422733904013474115 + 99.985883082635367436j)
|
||||
assert gamma(20+30j).ae(-1453876687.5534810 + 1163777777.8031573j)
|
||||
# Should always give exact factorials when they can
|
||||
# be represented as mpfs under the current working precision
|
||||
fact = 1
|
||||
for i in range(1, 18):
|
||||
assert gamma(i) == fact
|
||||
fact *= i
|
||||
for dps in [170, 600]:
|
||||
fact = 1
|
||||
mp.dps = dps
|
||||
for i in range(1, 105):
|
||||
assert gamma(i) == fact
|
||||
fact *= i
|
||||
mp.dps = 100
|
||||
assert gamma(0.5).ae(sqrt(pi))
|
||||
mp.dps = 15
|
||||
assert factorial(0) == fac(0) == 1
|
||||
assert factorial(3) == 6
|
||||
assert isnan(gamma(nan))
|
||||
assert gamma(1100).ae('4.8579168073569433667e2866')
|
||||
|
||||
def test_fac2():
|
||||
mp.dps = 15
|
||||
assert [fac2(n) for n in range(10)] == [1,1,2,3,8,15,48,105,384,945]
|
||||
assert fac2(-5).ae(1./3)
|
||||
assert fac2(-11).ae(-1./945)
|
||||
assert fac2(50).ae(5.20469842636666623e32)
|
||||
assert fac2(0.5+0.75j).ae(0.81546769394688069176-0.34901016085573266889j)
|
||||
assert fac2(inf) == inf
|
||||
assert isnan(fac2(-inf))
|
||||
|
||||
def test_gamma_quotients():
|
||||
mp.dps = 15
|
||||
h = 1e-8
|
||||
ep = 1e-4
|
||||
G = gamma
|
||||
assert gammaprod([-1],[-3,-4]) == 0
|
||||
assert gammaprod([-1,0],[-5]) == inf
|
||||
assert abs(gammaprod([-1],[-2]) - G(-1+h)/G(-2+h)) < 1e-4
|
||||
assert abs(gammaprod([-4,-3],[-2,0]) - G(-4+h)*G(-3+h)/G(-2+h)/G(0+h)) < 1e-4
|
||||
assert rf(3,0) == 1
|
||||
assert rf(2.5,1) == 2.5
|
||||
assert rf(-5,2) == 20
|
||||
assert rf(j,j).ae(gamma(2*j)/gamma(j))
|
||||
assert ff(-2,0) == 1
|
||||
assert ff(-2,1) == -2
|
||||
assert ff(4,3) == 24
|
||||
assert ff(3,4) == 0
|
||||
assert binomial(0,0) == 1
|
||||
assert binomial(1,0) == 1
|
||||
assert binomial(0,-1) == 0
|
||||
assert binomial(3,2) == 3
|
||||
assert binomial(5,2) == 10
|
||||
assert binomial(5,3) == 10
|
||||
assert binomial(5,5) == 1
|
||||
assert binomial(-1,0) == 1
|
||||
assert binomial(-2,-4) == 3
|
||||
assert binomial(4.5, 1.5) == 6.5625
|
||||
assert binomial(1100,1) == 1100
|
||||
assert binomial(1100,2) == 604450
|
||||
assert beta(1,1) == 1
|
||||
assert beta(0,0) == inf
|
||||
assert beta(3,0) == inf
|
||||
assert beta(-1,-1) == inf
|
||||
assert beta(1.5,1).ae(2/3.)
|
||||
assert beta(1.5,2.5).ae(pi/16)
|
||||
assert (10**15*beta(10,100)).ae(2.3455339739604649879)
|
||||
assert beta(inf,inf) == 0
|
||||
assert isnan(beta(-inf,inf))
|
||||
assert isnan(beta(-3,inf))
|
||||
assert isnan(beta(0,inf))
|
||||
assert beta(inf,0.5) == beta(0.5,inf) == 0
|
||||
assert beta(inf,-1.5) == inf
|
||||
assert beta(inf,-0.5) == -inf
|
||||
assert beta(1+2j,-1-j/2).ae(1.16396542451069943086+0.08511695947832914640j)
|
||||
assert beta(-0.5,0.5) == 0
|
||||
assert beta(-3,3).ae(-1/3.)
|
||||
|
||||
def test_zeta():
|
||||
mp.dps = 15
|
||||
assert zeta(2).ae(pi**2 / 6)
|
||||
assert zeta(2.0).ae(pi**2 / 6)
|
||||
assert zeta(mpc(2)).ae(pi**2 / 6)
|
||||
assert zeta(100).ae(1)
|
||||
assert zeta(0).ae(-0.5)
|
||||
assert zeta(0.5).ae(-1.46035450880958681)
|
||||
assert zeta(-1).ae(-mpf(1)/12)
|
||||
assert zeta(-2) == 0
|
||||
assert zeta(-3).ae(mpf(1)/120)
|
||||
assert zeta(-4) == 0
|
||||
assert zeta(-100) == 0
|
||||
assert isnan(zeta(nan))
|
||||
# Zeros in the critical strip
|
||||
assert zeta(mpc(0.5, 14.1347251417346937904)).ae(0)
|
||||
assert zeta(mpc(0.5, 21.0220396387715549926)).ae(0)
|
||||
assert zeta(mpc(0.5, 25.0108575801456887632)).ae(0)
|
||||
mp.dps = 50
|
||||
im = '236.5242296658162058024755079556629786895294952121891237'
|
||||
assert zeta(mpc(0.5, im)).ae(0, 1e-46)
|
||||
mp.dps = 15
|
||||
# Complex reflection formula
|
||||
assert (zeta(-60+3j) / 10**34).ae(8.6270183987866146+15.337398548226238j)
|
||||
|
||||
def test_altzeta():
|
||||
mp.dps = 15
|
||||
assert altzeta(-2) == 0
|
||||
assert altzeta(-4) == 0
|
||||
assert altzeta(-100) == 0
|
||||
assert altzeta(0) == 0.5
|
||||
assert altzeta(-1) == 0.25
|
||||
assert altzeta(-3) == -0.125
|
||||
assert altzeta(-5) == 0.25
|
||||
assert altzeta(-21) == 1180529130.25
|
||||
assert altzeta(1).ae(log(2))
|
||||
assert altzeta(2).ae(pi**2/12)
|
||||
assert altzeta(10).ae(73*pi**10/6842880)
|
||||
assert altzeta(50) < 1
|
||||
assert altzeta(60, rounding='d') < 1
|
||||
assert altzeta(60, rounding='u') == 1
|
||||
assert altzeta(10000, rounding='d') < 1
|
||||
assert altzeta(10000, rounding='u') == 1
|
||||
assert altzeta(3+0j) == altzeta(3)
|
||||
s = 3+4j
|
||||
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
|
||||
s = -3+4j
|
||||
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
|
||||
assert altzeta(-100.5).ae(4.58595480083585913e+108)
|
||||
assert altzeta(1.3).ae(0.73821404216623045)
|
||||
|
||||
def test_zeta_huge():
|
||||
mp.dps = 15
|
||||
assert zeta(inf) == 1
|
||||
mp.dps = 50
|
||||
assert zeta(100).ae('1.0000000000000000000000000000007888609052210118073522')
|
||||
assert zeta(40*pi).ae('1.0000000000000000000000000000000000000148407238666182')
|
||||
mp.dps = 10000
|
||||
v = zeta(33000)
|
||||
mp.dps = 15
|
||||
assert str(v-1) == '1.02363019598118e-9934'
|
||||
assert zeta(pi*1000, rounding=round_up) > 1
|
||||
assert zeta(3000, rounding=round_up) > 1
|
||||
assert zeta(pi*1000) == 1
|
||||
assert zeta(3000) == 1
|
||||
|
||||
def test_zeta_negative():
|
||||
mp.dps = 150
|
||||
a = -pi*10**40
|
||||
mp.dps = 15
|
||||
assert str(zeta(a)) == '2.55880492708712e+1233536161668617575553892558646631323374078'
|
||||
mp.dps = 50
|
||||
assert str(zeta(a)) == '2.5588049270871154960875033337384432038436330847333e+1233536161668617575553892558646631323374078'
|
||||
mp.dps = 15
|
||||
|
||||
def test_polygamma():
|
||||
mp.dps = 15
|
||||
psi0 = lambda z: psi(0,z)
|
||||
psi1 = lambda z: psi(1,z)
|
||||
assert psi0(3) == psi(0,3) == digamma(3)
|
||||
#assert psi2(3) == psi(2,3) == tetragamma(3)
|
||||
#assert psi3(3) == psi(3,3) == pentagamma(3)
|
||||
assert psi0(pi).ae(0.97721330794200673)
|
||||
assert psi0(-pi).ae(7.8859523853854902)
|
||||
assert psi0(-pi+1).ae(7.5676424992016996)
|
||||
assert psi0(pi+j).ae(1.04224048313859376 + 0.35853686544063749j)
|
||||
assert psi0(-pi-j).ae(1.3404026194821986 - 2.8824392476809402j)
|
||||
assert findroot(psi0, 1).ae(1.4616321449683622)
|
||||
assert psi0(inf) == inf
|
||||
assert psi1(inf) == 0
|
||||
assert psi(2,inf) == 0
|
||||
assert psi1(pi).ae(0.37424376965420049)
|
||||
assert psi1(-pi).ae(53.030438740085385)
|
||||
assert psi1(pi+j).ae(0.32935710377142464 - 0.12222163911221135j)
|
||||
assert psi1(-pi-j).ae(-0.30065008356019703 + 0.01149892486928227j)
|
||||
assert (10**6*psi(4,1+10*pi*j)).ae(-6.1491803479004446 - 0.3921316371664063j)
|
||||
assert psi0(1+10*pi*j).ae(3.4473994217222650 + 1.5548808324857071j)
|
||||
assert isnan(psi0(nan))
|
||||
assert isnan(psi0(-inf))
|
||||
assert psi0(-100.5).ae(4.615124601338064)
|
||||
assert psi0(3+0j).ae(psi0(3))
|
||||
assert psi0(-100+3j).ae(4.6106071768714086321+3.1117510556817394626j)
|
||||
|
||||
def test_polygamma_high_prec():
|
||||
mp.dps = 100
|
||||
assert str(psi(0,pi)) == "0.9772133079420067332920694864061823436408346099943256380095232865318105924777141317302075654362928734"
|
||||
assert str(psi(10,pi)) == "-12.98876181434889529310283769414222588307175962213707170773803550518307617769657562747174101900659238"
|
||||
|
||||
def test_polygamma_identities():
|
||||
mp.dps = 15
|
||||
psi0 = lambda z: psi(0,z)
|
||||
psi1 = lambda z: psi(1,z)
|
||||
psi2 = lambda z: psi(2,z)
|
||||
assert psi0(0.5).ae(-euler-2*log(2))
|
||||
assert psi0(1).ae(-euler)
|
||||
assert psi1(0.5).ae(0.5*pi**2)
|
||||
assert psi1(1).ae(pi**2/6)
|
||||
assert psi1(0.25).ae(pi**2 + 8*catalan)
|
||||
assert psi2(1).ae(-2*apery)
|
||||
mp.dps = 20
|
||||
u = -182*apery+4*sqrt(3)*pi**3
|
||||
mp.dps = 15
|
||||
assert psi(2,5/6.).ae(u)
|
||||
assert psi(3,0.5).ae(pi**4)
|
||||
|
||||
def test_foxtrot_identity():
|
||||
# A test of the complex digamma function.
|
||||
# See http://mathworld.wolfram.com/FoxTrotSeries.html and
|
||||
# http://mathworld.wolfram.com/DigammaFunction.html
|
||||
psi0 = lambda z: psi(0,z)
|
||||
mp.dps = 50
|
||||
a = (-1)**fraction(1,3)
|
||||
b = (-1)**fraction(2,3)
|
||||
x = -psi0(0.5*a) - psi0(-0.5*b) + psi0(0.5*(1+a)) + psi0(0.5*(1-b))
|
||||
y = 2*pi*sech(0.5*sqrt(3)*pi)
|
||||
assert x.ae(y)
|
||||
mp.dps = 15
|
||||
|
||||
def test_polygamma_high_order():
|
||||
mp.dps = 100
|
||||
assert str(psi(50, pi)) == "-1344100348958402765749252447726432491812.641985273160531055707095989227897753035823152397679626136483"
|
||||
assert str(psi(50, pi + 14*e)) == "-0.00000000000000000189793739550804321623512073101895801993019919886375952881053090844591920308111549337295143780341396"
|
||||
assert str(psi(50, pi + 14*e*j)) == ("(-0.0000000000000000522516941152169248975225472155683565752375889510631513244785"
|
||||
"9377385233700094871256507814151956624433 - 0.00000000000000001813157041407010184"
|
||||
"702414110218205348527862196327980417757665282244728963891298080199341480881811613j)")
|
||||
mp.dps = 15
|
||||
assert str(psi(50, pi)) == "-1.34410034895841e+39"
|
||||
assert str(psi(50, pi + 14*e)) == "-1.89793739550804e-18"
|
||||
assert str(psi(50, pi + 14*e*j)) == "(-5.2251694115217e-17 - 1.81315704140701e-17j)"
|
||||
|
||||
def test_harmonic():
|
||||
mp.dps = 15
|
||||
assert harmonic(0) == 0
|
||||
assert harmonic(1) == 1
|
||||
assert harmonic(2) == 1.5
|
||||
assert harmonic(3).ae(1. + 1./2 + 1./3)
|
||||
assert harmonic(10**10).ae(23.603066594891989701)
|
||||
assert harmonic(10**1000).ae(2303.162308658947)
|
||||
assert harmonic(0.5).ae(2-2*log(2))
|
||||
assert harmonic(inf) == inf
|
||||
assert harmonic(2+0j) == 1.5+0j
|
||||
assert harmonic(1+2j).ae(1.4918071802755104+0.92080728264223022j)
|
||||
|
||||
def test_gamma_huge_1():
|
||||
mp.dps = 500
|
||||
x = mpf(10**10) / 7
|
||||
mp.dps = 15
|
||||
assert str(gamma(x)) == "6.26075321389519e+12458010678"
|
||||
mp.dps = 50
|
||||
assert str(gamma(x)) == "6.2607532138951929201303779291707455874010420783933e+12458010678"
|
||||
mp.dps = 15
|
||||
|
||||
def test_gamma_huge_2():
|
||||
mp.dps = 500
|
||||
x = mpf(10**100) / 19
|
||||
mp.dps = 15
|
||||
assert str(gamma(x)) == (\
|
||||
"1.82341134776679e+5172997469323364168990133558175077136829182824042201886051511"
|
||||
"9656908623426021308685461258226190190661")
|
||||
mp.dps = 50
|
||||
assert str(gamma(x)) == (\
|
||||
"1.82341134776678875374414910350027596939980412984e+5172997469323364168990133558"
|
||||
"1750771368291828240422018860515119656908623426021308685461258226190190661")
|
||||
|
||||
def test_gamma_huge_3():
|
||||
mp.dps = 500
|
||||
x = 10**80 // 3 + 10**70*j / 7
|
||||
mp.dps = 15
|
||||
y = gamma(x)
|
||||
assert str(y.real) == (\
|
||||
"-6.82925203918106e+2636286142112569524501781477865238132302397236429627932441916"
|
||||
"056964386399485392600")
|
||||
assert str(y.imag) == (\
|
||||
"8.54647143678418e+26362861421125695245017814778652381323023972364296279324419160"
|
||||
"56964386399485392600")
|
||||
mp.dps = 50
|
||||
y = gamma(x)
|
||||
assert str(y.real) == (\
|
||||
"-6.8292520391810548460682736226799637356016538421817e+26362861421125695245017814"
|
||||
"77865238132302397236429627932441916056964386399485392600")
|
||||
assert str(y.imag) == (\
|
||||
"8.5464714367841748507479306948130687511711420234015e+263628614211256952450178147"
|
||||
"7865238132302397236429627932441916056964386399485392600")
|
||||
|
||||
def test_gamma_huge_4():
|
||||
x = 3200+11500j
|
||||
mp.dps = 15
|
||||
assert str(gamma(x)) == \
|
||||
"(8.95783268539713e+5164 - 1.94678798329735e+5164j)"
|
||||
mp.dps = 50
|
||||
assert str(gamma(x)) == (\
|
||||
"(8.9578326853971339570292952697675570822206567327092e+5164"
|
||||
" - 1.9467879832973509568895402139429643650329524144794e+51"
|
||||
"64j)")
|
||||
mp.dps = 15
|
||||
|
||||
def test_gamma_huge_5():
|
||||
mp.dps = 500
|
||||
x = 10**60 * j / 3
|
||||
mp.dps = 15
|
||||
y = gamma(x)
|
||||
assert str(y.real) == "-3.27753899634941e-227396058973640224580963937571892628368354580620654233316839"
|
||||
assert str(y.imag) == "-7.1519888950416e-227396058973640224580963937571892628368354580620654233316841"
|
||||
mp.dps = 50
|
||||
y = gamma(x)
|
||||
assert str(y.real) == (\
|
||||
"-3.2775389963494132168950056995974690946983219123935e-22739605897364022458096393"
|
||||
"7571892628368354580620654233316839")
|
||||
assert str(y.imag) == (\
|
||||
"-7.1519888950415979749736749222530209713136588885897e-22739605897364022458096393"
|
||||
"7571892628368354580620654233316841")
|
||||
mp.dps = 15
|
||||
|
||||
"""
|
||||
XXX: fails
|
||||
def test_gamma_huge_6():
|
||||
return
|
||||
mp.dps = 500
|
||||
x = -10**10 + mpf(10)**(-175)*j
|
||||
mp.dps = 15
|
||||
assert str(gamma(x)) == \
|
||||
"(1.86729378905343e-95657055178 - 4.29960285282433e-95657055002j)"
|
||||
mp.dps = 50
|
||||
assert str(gamma(x)) == (\
|
||||
"(1.8672937890534298925763143275474177736153484820662e-9565705517"
|
||||
"8 - 4.2996028528243336966001185406200082244961757496106e-9565705"
|
||||
"5002j)")
|
||||
mp.dps = 15
|
||||
"""
|
||||
|
||||
def test_gamma_huge_7():
|
||||
mp.dps = 100
|
||||
a = 3 + j/mpf(10)**1000
|
||||
mp.dps = 15
|
||||
y = gamma(a)
|
||||
assert str(y.real) == "2.0"
|
||||
assert str(y.imag) == "2.16735365342606e-1000"
|
||||
mp.dps = 50
|
||||
y = gamma(a)
|
||||
assert str(y.real) == "2.0"
|
||||
assert str(y.imag) == "2.1673536534260596065418805612488708028522563689298e-1000"
|
||||
|
||||
def test_stieltjes():
|
||||
mp.dps = 15
|
||||
assert stieltjes(0).ae(+euler)
|
||||
mp.dps = 25
|
||||
assert stieltjes(1).ae('-0.07281584548367672486058637587')
|
||||
assert stieltjes(2).ae('-0.009690363192872318484530386035')
|
||||
assert stieltjes(3).ae('0.002053834420303345866160046543')
|
||||
assert stieltjes(4).ae('0.002325370065467300057468170178')
|
||||
mp.dps = 15
|
||||
assert stieltjes(1).ae(-0.07281584548367672486058637587)
|
||||
assert stieltjes(2).ae(-0.009690363192872318484530386035)
|
||||
assert stieltjes(3).ae(0.002053834420303345866160046543)
|
||||
assert stieltjes(4).ae(0.0023253700654673000574681701775)
|
||||
|
||||
def test_barnesg():
|
||||
mp.dps = 15
|
||||
assert barnesg(0) == barnesg(-1) == 0
|
||||
assert [superfac(i) for i in range(8)] == [1, 1, 2, 12, 288, 34560, 24883200, 125411328000]
|
||||
assert str(superfac(1000)) == '3.24570818422368e+1177245'
|
||||
assert isnan(barnesg(nan))
|
||||
assert isnan(superfac(nan))
|
||||
assert isnan(hyperfac(nan))
|
||||
assert barnesg(inf) == inf
|
||||
assert superfac(inf) == inf
|
||||
assert hyperfac(inf) == inf
|
||||
assert isnan(superfac(-inf))
|
||||
assert barnesg(0.7).ae(0.8068722730141471)
|
||||
assert barnesg(2+3j).ae(-0.17810213864082169+0.04504542715447838j)
|
||||
assert [hyperfac(n) for n in range(7)] == [1, 1, 4, 108, 27648, 86400000, 4031078400000]
|
||||
assert [hyperfac(n) for n in range(0,-7,-1)] == [1,1,-1,-4,108,27648,-86400000]
|
||||
a = barnesg(-3+0j)
|
||||
assert a == 0 and isinstance(a, mpc)
|
||||
a = hyperfac(-3+0j)
|
||||
assert a == -4 and isinstance(a, mpc)
|
||||
|
||||
def test_polylog():
|
||||
mp.dps = 15
|
||||
zs = [mpmathify(z) for z in [0, 0.5, 0.99, 4, -0.5, -4, 1j, 3+4j]]
|
||||
for z in zs: assert polylog(1, z).ae(-log(1-z))
|
||||
for z in zs: assert polylog(0, z).ae(z/(1-z))
|
||||
for z in zs: assert polylog(-1, z).ae(z/(1-z)**2)
|
||||
for z in zs: assert polylog(-2, z).ae(z*(1+z)/(1-z)**3)
|
||||
for z in zs: assert polylog(-3, z).ae(z*(1+4*z+z**2)/(1-z)**4)
|
||||
assert polylog(3, 7).ae(5.3192579921456754382-5.9479244480803301023j)
|
||||
assert polylog(3, -7).ae(-4.5693548977219423182)
|
||||
assert polylog(2, 0.9).ae(1.2997147230049587252)
|
||||
assert polylog(2, -0.9).ae(-0.75216317921726162037)
|
||||
assert polylog(2, 0.9j).ae(-0.17177943786580149299+0.83598828572550503226j)
|
||||
assert polylog(2, 1.1).ae(1.9619991013055685931-0.2994257606855892575j)
|
||||
assert polylog(2, -1.1).ae(-0.89083809026228260587)
|
||||
assert polylog(2, 1.1*sqrt(j)).ae(0.58841571107611387722+1.09962542118827026011j)
|
||||
assert polylog(-2, 0.9).ae(1710)
|
||||
assert polylog(-2, -0.9).ae(-90/6859.)
|
||||
assert polylog(3, 0.9).ae(1.0496589501864398696)
|
||||
assert polylog(-3, 0.9).ae(48690)
|
||||
assert polylog(-3, -4).ae(-0.0064)
|
||||
assert polylog(0.5+j/3, 0.5+j/2).ae(0.31739144796565650535 + 0.99255390416556261437j)
|
||||
assert polylog(3+4j,1).ae(zeta(3+4j))
|
||||
assert polylog(3+4j,-1).ae(-altzeta(3+4j))
|
||||
|
||||
def test_bell_polyexp():
|
||||
mp.dps = 15
|
||||
# TODO: more tests for polyexp
|
||||
assert (polyexp(0,1e-10)*10**10).ae(1.00000000005)
|
||||
assert (polyexp(1,1e-10)*10**10).ae(1.0000000001)
|
||||
assert polyexp(5,3j).ae(-607.7044517476176454+519.962786482001476087j)
|
||||
assert polyexp(-1,3.5).ae(12.09537536175543444)
|
||||
# bell(0,x) = 1
|
||||
assert bell(0,0) == 1
|
||||
assert bell(0,1) == 1
|
||||
assert bell(0,2) == 1
|
||||
assert bell(0,inf) == 1
|
||||
assert bell(0,-inf) == 1
|
||||
assert isnan(bell(0,nan))
|
||||
# bell(1,x) = x
|
||||
assert bell(1,4) == 4
|
||||
assert bell(1,0) == 0
|
||||
assert bell(1,inf) == inf
|
||||
assert bell(1,-inf) == -inf
|
||||
assert isnan(bell(1,nan))
|
||||
# bell(2,x) = x*(1+x)
|
||||
assert bell(2,-1) == 0
|
||||
assert bell(2,0) == 0
|
||||
# large orders / arguments
|
||||
assert bell(10) == 115975
|
||||
assert bell(10,1) == 115975
|
||||
assert bell(10, -8) == 11054008
|
||||
assert bell(5,-50) == -253087550
|
||||
assert bell(50,-50).ae('3.4746902914629720259e74')
|
||||
mp.dps = 80
|
||||
assert bell(50,-50) == 347469029146297202586097646631767227177164818163463279814268368579055777450
|
||||
assert bell(40,50) == 5575520134721105844739265207408344706846955281965031698187656176321717550
|
||||
assert bell(74) == 5006908024247925379707076470957722220463116781409659160159536981161298714301202
|
||||
mp.dps = 15
|
||||
assert bell(10,20j) == 7504528595600+15649605360020j
|
||||
# continuity of the generalization
|
||||
assert bell(0.5,0).ae(sinc(pi*0.5))
|
||||
|
||||
def test_primezeta():
|
||||
mp.dps = 15
|
||||
assert primezeta(0.9).ae(1.8388316154446882243 + 3.1415926535897932385j)
|
||||
assert primezeta(4).ae(0.076993139764246844943)
|
||||
assert primezeta(1) == inf
|
||||
assert primezeta(inf) == 0
|
||||
assert isnan(primezeta(nan))
|
||||
|
||||
def test_rs_zeta():
|
||||
mp.dps = 15
|
||||
assert zeta(0.5+100000j).ae(1.0730320148577531321 + 5.7808485443635039843j)
|
||||
assert zeta(0.75+100000j).ae(1.837852337251873704 + 1.9988492668661145358j)
|
||||
assert zeta(0.5+1000000j, derivative=3).ae(1647.7744105852674733 - 1423.1270943036622097j)
|
||||
assert zeta(1+1000000j, derivative=3).ae(3.4085866124523582894 - 18.179184721525947301j)
|
||||
assert zeta(1+1000000j, derivative=1).ae(-0.10423479366985452134 - 0.74728992803359056244j)
|
||||
assert zeta(0.5-1000000j, derivative=1).ae(11.636804066002521459 + 17.127254072212996004j)
|
||||
# Additional sanity tests using fp arithmetic.
|
||||
# Some more high-precision tests are found in the docstrings
|
||||
def ae(x, y, tol=1e-6):
|
||||
return abs(x-y) < tol*abs(y)
|
||||
assert ae(fp.zeta(0.5-100000j), 1.0730320148577531321 - 5.7808485443635039843j)
|
||||
assert ae(fp.zeta(0.75-100000j), 1.837852337251873704 - 1.9988492668661145358j)
|
||||
assert ae(fp.zeta(0.5+1e6j), 0.076089069738227100006 + 2.8051021010192989554j)
|
||||
assert ae(fp.zeta(0.5+1e6j, derivative=1), 11.636804066002521459 - 17.127254072212996004j)
|
||||
assert ae(fp.zeta(1+1e6j), 0.94738726251047891048 + 0.59421999312091832833j)
|
||||
assert ae(fp.zeta(1+1e6j, derivative=1), -0.10423479366985452134 - 0.74728992803359056244j)
|
||||
assert ae(fp.zeta(0.5+100000j, derivative=1), 10.766962036817482375 - 30.92705282105996714j)
|
||||
assert ae(fp.zeta(0.5+100000j, derivative=2), -119.40515625740538429 + 217.14780631141830251j)
|
||||
assert ae(fp.zeta(0.5+100000j, derivative=3), 1129.7550282628460881 - 1685.4736895169690346j)
|
||||
assert ae(fp.zeta(0.5+100000j, derivative=4), -10407.160819314958615 + 13777.786698628045085j)
|
||||
assert ae(fp.zeta(0.75+100000j, derivative=1), -0.41742276699594321475 - 6.4453816275049955949j)
|
||||
assert ae(fp.zeta(0.75+100000j, derivative=2), -9.214314279161977266 + 35.07290795337967899j)
|
||||
assert ae(fp.zeta(0.75+100000j, derivative=3), 110.61331857820103469 - 236.87847130518129926j)
|
||||
assert ae(fp.zeta(0.75+100000j, derivative=4), -1054.334275898559401 + 1769.9177890161596383j)
|
||||
|
||||
def test_zeta_near_1():
|
||||
# Test for a former bug in mpf_zeta and mpc_zeta
|
||||
mp.dps = 15
|
||||
s1 = fadd(1, '1e-10', exact=True)
|
||||
s2 = fadd(1, '-1e-10', exact=True)
|
||||
s3 = fadd(1, '1e-10j', exact=True)
|
||||
assert zeta(s1).ae(1.000000000057721566490881444e10)
|
||||
assert zeta(s2).ae(-9.99999999942278433510574872e9)
|
||||
z = zeta(s3)
|
||||
assert z.real.ae(0.57721566490153286060)
|
||||
assert z.imag.ae(-9.9999999999999999999927184e9)
|
||||
mp.dps = 30
|
||||
s1 = fadd(1, '1e-50', exact=True)
|
||||
s2 = fadd(1, '-1e-50', exact=True)
|
||||
s3 = fadd(1, '1e-50j', exact=True)
|
||||
assert zeta(s1).ae('1e50')
|
||||
assert zeta(s2).ae('-1e50')
|
||||
z = zeta(s3)
|
||||
assert z.real.ae('0.57721566490153286060651209008240243104215933593992')
|
||||
assert z.imag.ae('-1e50')
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
"""
|
||||
Check that the output from irrational functions is accurate for
|
||||
high-precision input, from 5 to 200 digits. The reference values were
|
||||
verified with Mathematica.
|
||||
"""
|
||||
|
||||
import time
|
||||
from mpmath import *
|
||||
|
||||
precs = [5, 15, 28, 35, 57, 80, 100, 150, 200]
|
||||
|
||||
# sqrt(3) + pi/2
|
||||
a = \
|
||||
"3.302847134363773912758768033145623809041389953497933538543279275605"\
|
||||
"841220051904536395163599428307109666700184672047856353516867399774243594"\
|
||||
"67433521615861420725323528325327484262075464241255915238845599752675"
|
||||
|
||||
# e + 1/euler**2
|
||||
b = \
|
||||
"5.719681166601007617111261398629939965860873957353320734275716220045750"\
|
||||
"31474116300529519620938123730851145473473708966080207482581266469342214"\
|
||||
"824842256999042984813905047895479210702109260221361437411947323431"
|
||||
|
||||
# sqrt(a)
|
||||
sqrt_a = \
|
||||
"1.817373691447021556327498239690365674922395036495564333152483422755"\
|
||||
"144321726165582817927383239308173567921345318453306994746434073691275094"\
|
||||
"484777905906961689902608644112196725896908619756404253109722911487"
|
||||
|
||||
# sqrt(a+b*i).real
|
||||
sqrt_abi_real = \
|
||||
"2.225720098415113027729407777066107959851146508557282707197601407276"\
|
||||
"89160998185797504198062911768240808839104987021515555650875977724230130"\
|
||||
"3584116233925658621288393930286871862273400475179312570274423840384"
|
||||
|
||||
# sqrt(a+b*i).imag
|
||||
sqrt_abi_imag = \
|
||||
"1.2849057639084690902371581529110949983261182430040898147672052833653668"\
|
||||
"0629534491275114877090834296831373498336559849050755848611854282001250"\
|
||||
"1924311019152914021365263161630765255610885489295778894976075186"
|
||||
|
||||
# log(a)
|
||||
log_a = \
|
||||
"1.194784864491089550288313512105715261520511949410072046160598707069"\
|
||||
"4336653155025770546309137440687056366757650909754708302115204338077595203"\
|
||||
"83005773986664564927027147084436553262269459110211221152925732612"
|
||||
|
||||
# log(a+b*i).real
|
||||
log_abi_real = \
|
||||
"1.8877985921697018111624077550443297276844736840853590212962006811663"\
|
||||
"04949387789489704203167470111267581371396245317618589339274243008242708"\
|
||||
"014251531496104028712866224020066439049377679709216784954509456421"
|
||||
|
||||
# log(a+b*i).imag
|
||||
log_abi_imag = \
|
||||
"1.0471204952840802663567714297078763189256357109769672185219334169734948"\
|
||||
"4265809854092437285294686651806426649541504240470168212723133326542181"\
|
||||
"8300136462287639956713914482701017346851009323172531601894918640"
|
||||
|
||||
# exp(a)
|
||||
exp_a = \
|
||||
"27.18994224087168661137253262213293847994194869430518354305430976149"\
|
||||
"382792035050358791398632888885200049857986258414049540376323785711941636"\
|
||||
"100358982497583832083513086941635049329804685212200507288797531143"
|
||||
|
||||
# exp(a+b*i).real
|
||||
exp_abi_real = \
|
||||
"22.98606617170543596386921087657586890620262522816912505151109385026"\
|
||||
"40160179326569526152851983847133513990281518417211964710397233157168852"\
|
||||
"4963130831190142571659948419307628119985383887599493378056639916701"
|
||||
|
||||
# exp(a+b*i).imag
|
||||
exp_abi_imag = \
|
||||
"-14.523557450291489727214750571590272774669907424478129280902375851196283"\
|
||||
"3377162379031724734050088565710975758824441845278120105728824497308303"\
|
||||
"6065619788140201636218705414429933685889542661364184694108251449"
|
||||
|
||||
# a**b
|
||||
pow_a_b = \
|
||||
"928.7025342285568142947391505837660251004990092821305668257284426997"\
|
||||
"361966028275685583421197860603126498884545336686124793155581311527995550"\
|
||||
"580229264427202446131740932666832138634013168125809402143796691154"
|
||||
|
||||
# (a**(a+b*i)).real
|
||||
pow_a_abi_real = \
|
||||
"44.09156071394489511956058111704382592976814280267142206420038656267"\
|
||||
"67707916510652790502399193109819563864568986234654864462095231138500505"\
|
||||
"8197456514795059492120303477512711977915544927440682508821426093455"
|
||||
|
||||
# (a**(a+b*i)).imag
|
||||
pow_a_abi_imag = \
|
||||
"27.069371511573224750478105146737852141664955461266218367212527612279886"\
|
||||
"9322304536553254659049205414427707675802193810711302947536332040474573"\
|
||||
"8166261217563960235014674118610092944307893857862518964990092301"
|
||||
|
||||
# ((a+b*i)**(a+b*i)).real
|
||||
pow_abi_abi_real = \
|
||||
"-0.15171310677859590091001057734676423076527145052787388589334350524"\
|
||||
"8084195882019497779202452975350579073716811284169068082670778986235179"\
|
||||
"0813026562962084477640470612184016755250592698408112493759742219150452"\
|
||||
|
||||
# ((a+b*i)**(a+b*i)).imag
|
||||
pow_abi_abi_imag = \
|
||||
"1.2697592504953448936553147870155987153192995316950583150964099070426"\
|
||||
"4736837932577176947632535475040521749162383347758827307504526525647759"\
|
||||
"97547638617201824468382194146854367480471892602963428122896045019902"
|
||||
|
||||
# sin(a)
|
||||
sin_a = \
|
||||
"-0.16055653857469062740274792907968048154164433772938156243509084009"\
|
||||
"38437090841460493108570147191289893388608611542655654723437248152535114"\
|
||||
"528368009465836614227575701220612124204622383149391870684288862269631"
|
||||
|
||||
# sin(1000*a)
|
||||
sin_1000a = \
|
||||
"-0.85897040577443833776358106803777589664322997794126153477060795801"\
|
||||
"09151695416961724733492511852267067419573754315098042850381158563024337"\
|
||||
"216458577140500488715469780315833217177634490142748614625281171216863"
|
||||
|
||||
# sin(a+b*i)
|
||||
sin_abi_real = \
|
||||
"-24.4696999681556977743346798696005278716053366404081910969773939630"\
|
||||
"7149215135459794473448465734589287491880563183624997435193637389884206"\
|
||||
"02151395451271809790360963144464736839412254746645151672423256977064"
|
||||
|
||||
sin_abi_imag = \
|
||||
"-150.42505378241784671801405965872972765595073690984080160750785565810981"\
|
||||
"8314482499135443827055399655645954830931316357243750839088113122816583"\
|
||||
"7169201254329464271121058839499197583056427233866320456505060735"
|
||||
|
||||
# cos
|
||||
cos_a = \
|
||||
"-0.98702664499035378399332439243967038895709261414476495730788864004"\
|
||||
"05406821549361039745258003422386169330787395654908532996287293003581554"\
|
||||
"257037193284199198069707141161341820684198547572456183525659969145501"
|
||||
|
||||
cos_1000a = \
|
||||
"-0.51202523570982001856195696460663971099692261342827540426136215533"\
|
||||
"52686662667660613179619804463250686852463876088694806607652218586060613"\
|
||||
"951310588158830695735537073667299449753951774916401887657320950496820"
|
||||
|
||||
# tan
|
||||
tan_a = \
|
||||
"0.162666873675188117341401059858835168007137819495998960250142156848"\
|
||||
"639654718809412181543343168174807985559916643549174530459883826451064966"\
|
||||
"7996119428949951351938178809444268785629011625179962457123195557310"
|
||||
|
||||
tan_abi_real = \
|
||||
"6.822696615947538488826586186310162599974827139564433912601918442911"\
|
||||
"1026830824380070400102213741875804368044342309515353631134074491271890"\
|
||||
"467615882710035471686578162073677173148647065131872116479947620E-6"
|
||||
|
||||
tan_abi_imag = \
|
||||
"0.9999795833048243692245661011298447587046967777739649018690797625964167"\
|
||||
"1446419978852235960862841608081413169601038230073129482874832053357571"\
|
||||
"62702259309150715669026865777947502665936317953101462202542168429"
|
||||
|
||||
|
||||
def test_hp():
|
||||
for dps in precs:
|
||||
mp.dps = dps + 8
|
||||
aa = mpf(a)
|
||||
bb = mpf(b)
|
||||
a1000 = 1000*mpf(a)
|
||||
abi = mpc(aa, bb)
|
||||
mp.dps = dps
|
||||
assert (sqrt(3) + pi/2).ae(aa)
|
||||
assert (e + 1/euler**2).ae(bb)
|
||||
|
||||
assert sqrt(aa).ae(mpf(sqrt_a))
|
||||
assert sqrt(abi).ae(mpc(sqrt_abi_real, sqrt_abi_imag))
|
||||
|
||||
assert log(aa).ae(mpf(log_a))
|
||||
assert log(abi).ae(mpc(log_abi_real, log_abi_imag))
|
||||
|
||||
assert exp(aa).ae(mpf(exp_a))
|
||||
assert exp(abi).ae(mpc(exp_abi_real, exp_abi_imag))
|
||||
|
||||
assert (aa**bb).ae(mpf(pow_a_b))
|
||||
assert (aa**abi).ae(mpc(pow_a_abi_real, pow_a_abi_imag))
|
||||
assert (abi**abi).ae(mpc(pow_abi_abi_real, pow_abi_abi_imag))
|
||||
|
||||
assert sin(a).ae(mpf(sin_a))
|
||||
assert sin(a1000).ae(mpf(sin_1000a))
|
||||
assert sin(abi).ae(mpc(sin_abi_real, sin_abi_imag))
|
||||
|
||||
assert cos(a).ae(mpf(cos_a))
|
||||
assert cos(a1000).ae(mpf(cos_1000a))
|
||||
|
||||
assert tan(a).ae(mpf(tan_a))
|
||||
assert tan(abi).ae(mpc(tan_abi_real, tan_abi_imag))
|
||||
|
||||
# check that complex cancellation is avoided so that both
|
||||
# real and imaginary parts have high relative accuracy.
|
||||
# abs_eps should be 0, but has to be set to 1e-205 to pass the
|
||||
# 200-digit case, probably due to slight inaccuracy in the
|
||||
# precomputed input
|
||||
assert (tan(abi).real).ae(mpf(tan_abi_real), abs_eps=1e-205)
|
||||
assert (tan(abi).imag).ae(mpf(tan_abi_imag), abs_eps=1e-205)
|
||||
mp.dps = 460
|
||||
assert str(log(3))[-20:] == '02166121184001409826'
|
||||
mp.dps = 15
|
||||
|
||||
# Since str(a) can differ in the last digit from rounded a, and I want
|
||||
# to compare the last digits of big numbers with the results in Mathematica,
|
||||
# I made this hack to get the last 20 digits of rounded a
|
||||
|
||||
def last_digits(a):
|
||||
r = repr(a)
|
||||
s = str(a)
|
||||
#dps = mp.dps
|
||||
#mp.dps += 3
|
||||
m = 10
|
||||
r = r.replace(s[:-m],'')
|
||||
r = r.replace("mpf('",'').replace("')",'')
|
||||
num0 = 0
|
||||
for c in r:
|
||||
if c == '0':
|
||||
num0 += 1
|
||||
else:
|
||||
break
|
||||
b = float(int(r))/10**(len(r) - m)
|
||||
if b >= 10**m - 0.5:
|
||||
raise NotImplementedError
|
||||
n = int(round(b))
|
||||
sn = str(n)
|
||||
s = s[:-m] + '0'*num0 + sn
|
||||
return s[-20:]
|
||||
|
||||
# values checked with Mathematica
|
||||
def test_log_hp():
|
||||
mp.dps = 2000
|
||||
a = mpf(10)**15000/3
|
||||
r = log(a)
|
||||
res = last_digits(r)
|
||||
# Mathematica N[Log[10^15000/3], 2000]
|
||||
# ...7443804441768333470331
|
||||
assert res == '44380444176833347033'
|
||||
|
||||
# see issue 105
|
||||
r = log(mpf(3)/2)
|
||||
# Mathematica N[Log[3/2], 2000]
|
||||
# ...69653749808140753263288
|
||||
res = last_digits(r)
|
||||
assert res == '53749808140753263288'
|
||||
|
||||
mp.dps = 10000
|
||||
r = log(2)
|
||||
res = last_digits(r)
|
||||
# Mathematica N[Log[2], 10000]
|
||||
# ...695615913401856601359655561
|
||||
assert res == '91340185660135965556'
|
||||
r = log(mpf(10)**10/3)
|
||||
res = last_digits(r)
|
||||
# Mathematica N[Log[10^10/3], 10000]
|
||||
# ...587087654020631943060007154
|
||||
assert res == '54020631943060007154', res
|
||||
r = log(mpf(10)**100/3)
|
||||
res = last_digits(r)
|
||||
# Mathematica N[Log[10^100/3], 10000]
|
||||
# ,,,59246336539088351652334666
|
||||
assert res == '36539088351652334666', res
|
||||
mp.dps += 10
|
||||
a = 1 - mpf(1)/10**10
|
||||
mp.dps -= 10
|
||||
r = log(a)
|
||||
res = last_digits(r)
|
||||
# ...3310334360482956137216724048322957404
|
||||
# 372167240483229574038733026370
|
||||
# Mathematica N[Log[1 - 10^-10]*10^10, 10000]
|
||||
# ...60482956137216724048322957404
|
||||
assert res == '37216724048322957404', res
|
||||
mp.dps = 10000
|
||||
mp.dps += 100
|
||||
a = 1 + mpf(1)/10**100
|
||||
mp.dps -= 100
|
||||
|
||||
r = log(a)
|
||||
res = last_digits(+r)
|
||||
# Mathematica N[Log[1 + 10^-100]*10^10, 10030]
|
||||
# ...3994733877377412241546890854692521568292338268273 10^-91
|
||||
assert res == '39947338773774122415', res
|
||||
|
||||
mp.dps = 15
|
||||
|
||||
def test_exp_hp():
|
||||
mp.dps = 4000
|
||||
r = exp(mpf(1)/10)
|
||||
# IntegerPart[N[Exp[1/10] * 10^4000, 4000]]
|
||||
# ...92167105162069688129
|
||||
assert int(r * 10**mp.dps) % 10**20 == 92167105162069688129
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_pslq():
|
||||
mp.dps = 15
|
||||
assert pslq([3*pi+4*e/7, pi, e, log(2)]) == [7, -21, -4, 0]
|
||||
assert pslq([4.9999999999999991, 1]) == [1, -5]
|
||||
assert pslq([2,1]) == [1, -2]
|
||||
|
||||
def test_identify():
|
||||
mp.dps = 20
|
||||
assert identify(zeta(4), ['log(2)', 'pi**4']) == '((1/90)*pi**4)'
|
||||
mp.dps = 15
|
||||
assert identify(exp(5)) == 'exp(5)'
|
||||
assert identify(exp(4)) == 'exp(4)'
|
||||
assert identify(log(5)) == 'log(5)'
|
||||
assert identify(exp(3*pi), ['pi']) == 'exp((3*pi))'
|
||||
assert identify(3, full=True) == ['3', '3', '1/(1/3)', 'sqrt(9)',
|
||||
'1/sqrt((1/9))', '(sqrt(12)/2)**2', '1/(sqrt(12)/6)**2']
|
||||
assert identify(pi+1, {'a':+pi}) == '(1 + 1*a)'
|
||||
|
|
@ -1,264 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
mpi_to_str = mp.mpi_to_str
|
||||
mpi_from_str = mp.mpi_from_str
|
||||
|
||||
def test_interval_identity():
|
||||
mp.dps = 15
|
||||
assert mpi(2) == mpi(2, 2)
|
||||
assert mpi(2) != mpi(-2, 2)
|
||||
assert not (mpi(2) != mpi(2, 2))
|
||||
assert mpi(-1, 1) == mpi(-1, 1)
|
||||
assert str(mpi('0.1')) == "[0.099999999999999991673, 0.10000000000000000555]"
|
||||
assert repr(mpi('0.1')) == "mpi(mpf('0.099999999999999992'), mpf('0.10000000000000001'))"
|
||||
u = mpi(-1, 3)
|
||||
assert -1 in u
|
||||
assert 2 in u
|
||||
assert 3 in u
|
||||
assert -1.1 not in u
|
||||
assert 3.1 not in u
|
||||
assert mpi(-1, 3) in u
|
||||
assert mpi(0, 1) in u
|
||||
assert mpi(-1.1, 2) not in u
|
||||
assert mpi(2.5, 3.1) not in u
|
||||
w = mpi(-inf, inf)
|
||||
assert mpi(-5, 5) in w
|
||||
assert mpi(2, inf) in w
|
||||
assert mpi(0, 2) in mpi(0, 10)
|
||||
assert not (3 in mpi(-inf, 0))
|
||||
|
||||
def test_interval_arithmetic():
|
||||
mp.dps = 15
|
||||
assert mpi(2) + mpi(3,4) == mpi(5,6)
|
||||
assert mpi(1, 2)**2 == mpi(1, 4)
|
||||
assert mpi(1) + mpi(0, 1e-50) == mpi(1, mpf('1.0000000000000002'))
|
||||
x = 1 / (1 / mpi(3))
|
||||
assert x.a < 3 < x.b
|
||||
x = mpi(2) ** mpi(0.5)
|
||||
mp.dps += 5
|
||||
sq = sqrt(2)
|
||||
mp.dps -= 5
|
||||
assert x.a < sq < x.b
|
||||
assert mpi(1) / mpi(1, inf)
|
||||
assert mpi(2, 3) / inf == mpi(0, 0)
|
||||
assert mpi(0) / inf == 0
|
||||
assert mpi(0) / 0 == mpi(-inf, inf)
|
||||
assert mpi(inf) / 0 == mpi(-inf, inf)
|
||||
assert mpi(0) * inf == mpi(-inf, inf)
|
||||
assert 1 / mpi(2, inf) == mpi(0, 0.5)
|
||||
assert str((mpi(50, 50) * mpi(-10, -10)) / 3) == \
|
||||
'[-166.66666666666668561, -166.66666666666665719]'
|
||||
assert mpi(0, 4) ** 3 == mpi(0, 64)
|
||||
assert mpi(2,4).mid == 3
|
||||
mp.dps = 30
|
||||
a = mpi(pi)
|
||||
mp.dps = 15
|
||||
b = +a
|
||||
assert b.a < a.a
|
||||
assert b.b > a.b
|
||||
a = mpi(pi)
|
||||
assert a == +a
|
||||
assert abs(mpi(-1,2)) == mpi(0,2)
|
||||
assert abs(mpi(0.5,2)) == mpi(0.5,2)
|
||||
assert abs(mpi(-3,2)) == mpi(0,3)
|
||||
assert abs(mpi(-3,-0.5)) == mpi(0.5,3)
|
||||
assert mpi(0) * mpi(2,3) == mpi(0)
|
||||
assert mpi(2,3) * mpi(0) == mpi(0)
|
||||
assert mpi(1,3).delta == 2
|
||||
assert mpi(1,2) - mpi(3,4) == mpi(-3,-1)
|
||||
assert mpi(-inf,0) - mpi(0,inf) == mpi(-inf,0)
|
||||
assert mpi(-inf,0) - mpi(-inf,inf) == mpi(-inf,inf)
|
||||
assert mpi(0,inf) - mpi(-inf,1) == mpi(-1,inf)
|
||||
|
||||
def test_interval_mul():
|
||||
assert mpi(-1, 0) * inf == mpi(-inf, 0)
|
||||
assert mpi(-1, 0) * -inf == mpi(0, inf)
|
||||
assert mpi(0, 1) * inf == mpi(0, inf)
|
||||
assert mpi(0, 1) * mpi(0, inf) == mpi(0, inf)
|
||||
assert mpi(-1, 1) * inf == mpi(-inf, inf)
|
||||
assert mpi(-1, 1) * mpi(0, inf) == mpi(-inf, inf)
|
||||
assert mpi(-1, 1) * mpi(-inf, inf) == mpi(-inf, inf)
|
||||
assert mpi(-inf, 0) * mpi(0, 1) == mpi(-inf, 0)
|
||||
assert mpi(-inf, 0) * mpi(0, 0) * mpi(-inf, 0)
|
||||
assert mpi(-inf, 0) * mpi(-inf, inf) == mpi(-inf, inf)
|
||||
assert mpi(-5,0)*mpi(-32,28) == mpi(-140,160)
|
||||
assert mpi(2,3) * mpi(-1,2) == mpi(-3,6)
|
||||
# Should be undefined?
|
||||
assert mpi(inf, inf) * 0 == mpi(-inf, inf)
|
||||
assert mpi(-inf, -inf) * 0 == mpi(-inf, inf)
|
||||
assert mpi(0) * mpi(-inf,2) == mpi(-inf,inf)
|
||||
assert mpi(0) * mpi(-2,inf) == mpi(-inf,inf)
|
||||
assert mpi(-2,inf) * mpi(0) == mpi(-inf,inf)
|
||||
assert mpi(-inf,2) * mpi(0) == mpi(-inf,inf)
|
||||
|
||||
def test_interval_pow():
|
||||
assert mpi(3)**2 == mpi(9, 9)
|
||||
assert mpi(-3)**2 == mpi(9, 9)
|
||||
assert mpi(-3, 1)**2 == mpi(0, 9)
|
||||
assert mpi(-3, -1)**2 == mpi(1, 9)
|
||||
assert mpi(-3, -1)**3 == mpi(-27, -1)
|
||||
assert mpi(-3, 1)**3 == mpi(-27, 1)
|
||||
assert mpi(-2, 3)**2 == mpi(0, 9)
|
||||
assert mpi(-3, 2)**2 == mpi(0, 9)
|
||||
assert mpi(4) ** -1 == mpi(0.25, 0.25)
|
||||
assert mpi(-4) ** -1 == mpi(-0.25, -0.25)
|
||||
assert mpi(4) ** -2 == mpi(0.0625, 0.0625)
|
||||
assert mpi(-4) ** -2 == mpi(0.0625, 0.0625)
|
||||
assert mpi(0, 1) ** inf == mpi(0, 1)
|
||||
assert mpi(0, 1) ** -inf == mpi(1, inf)
|
||||
assert mpi(0, inf) ** inf == mpi(0, inf)
|
||||
assert mpi(0, inf) ** -inf == mpi(0, inf)
|
||||
assert mpi(1, inf) ** inf == mpi(1, inf)
|
||||
assert mpi(1, inf) ** -inf == mpi(0, 1)
|
||||
assert mpi(2, 3) ** 1 == mpi(2, 3)
|
||||
assert mpi(2, 3) ** 0 == 1
|
||||
assert mpi(1,3) ** mpi(2) == mpi(1,9)
|
||||
|
||||
def test_interval_sqrt():
|
||||
assert mpi(4) ** 0.5 == mpi(2)
|
||||
|
||||
def test_interval_div():
|
||||
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
|
||||
assert mpi(0, 1) / mpi(0, 1) == mpi(0, inf)
|
||||
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
|
||||
assert mpi(inf, inf) / mpi(2, inf) == mpi(0, inf)
|
||||
assert mpi(inf, inf) / mpi(2, 2) == mpi(inf, inf)
|
||||
assert mpi(0, inf) / mpi(2, inf) == mpi(0, inf)
|
||||
assert mpi(0, inf) / mpi(2, 2) == mpi(0, inf)
|
||||
assert mpi(2, inf) / mpi(2, 2) == mpi(1, inf)
|
||||
assert mpi(2, inf) / mpi(2, inf) == mpi(0, inf)
|
||||
assert mpi(-4, 8) / mpi(1, inf) == mpi(-4, 8)
|
||||
assert mpi(-4, 8) / mpi(0.5, inf) == mpi(-8, 16)
|
||||
assert mpi(-inf, 8) / mpi(0.5, inf) == mpi(-inf, 16)
|
||||
assert mpi(-inf, inf) / mpi(0.5, inf) == mpi(-inf, inf)
|
||||
assert mpi(8, inf) / mpi(0.5, inf) == mpi(0, inf)
|
||||
assert mpi(-8, inf) / mpi(0.5, inf) == mpi(-16, inf)
|
||||
assert mpi(-4, 8) / mpi(inf, inf) == mpi(0, 0)
|
||||
assert mpi(0, 8) / mpi(inf, inf) == mpi(0, 0)
|
||||
assert mpi(0, 0) / mpi(inf, inf) == mpi(0, 0)
|
||||
assert mpi(-inf, 0) / mpi(inf, inf) == mpi(-inf, 0)
|
||||
assert mpi(-inf, 8) / mpi(inf, inf) == mpi(-inf, 0)
|
||||
assert mpi(-inf, inf) / mpi(inf, inf) == mpi(-inf, inf)
|
||||
assert mpi(-8, inf) / mpi(inf, inf) == mpi(0, inf)
|
||||
assert mpi(0, inf) / mpi(inf, inf) == mpi(0, inf)
|
||||
assert mpi(8, inf) / mpi(inf, inf) == mpi(0, inf)
|
||||
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
|
||||
assert mpi(-1, 2) / mpi(0, 1) == mpi(-inf, +inf)
|
||||
assert mpi(0, 1) / mpi(0, 1) == mpi(0.0, +inf)
|
||||
assert mpi(-1, 0) / mpi(0, 1) == mpi(-inf, 0.0)
|
||||
assert mpi(-0.5, -0.25) / mpi(0, 1) == mpi(-inf, -0.25)
|
||||
assert mpi(0.5, 1) / mpi(0, 1) == mpi(0.5, +inf)
|
||||
assert mpi(0.5, 4) / mpi(0, 1) == mpi(0.5, +inf)
|
||||
assert mpi(-1, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
|
||||
assert mpi(-4, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
|
||||
assert mpi(-1, 2) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(0, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(-1, 0) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(-0.5, -0.25) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(0.5, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(0.5, 4) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(-1, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(-4, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
||||
assert mpi(-1, 2) / mpi(-1, 0) == mpi(-inf, +inf)
|
||||
assert mpi(0, 1) / mpi(-1, 0) == mpi(-inf, 0.0)
|
||||
assert mpi(-1, 0) / mpi(-1, 0) == mpi(0.0, +inf)
|
||||
assert mpi(-0.5, -0.25) / mpi(-1, 0) == mpi(0.25, +inf)
|
||||
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
|
||||
assert mpi(0.5, 4) / mpi(-1, 0) == mpi(-inf, -0.5)
|
||||
assert mpi(-1, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
|
||||
assert mpi(-4, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
|
||||
assert mpi(-1, 2) / mpi(0.5, 1) == mpi(-2.0, 4.0)
|
||||
assert mpi(0, 1) / mpi(0.5, 1) == mpi(0.0, 2.0)
|
||||
assert mpi(-1, 0) / mpi(0.5, 1) == mpi(-2.0, 0.0)
|
||||
assert mpi(-0.5, -0.25) / mpi(0.5, 1) == mpi(-1.0, -0.25)
|
||||
assert mpi(0.5, 1) / mpi(0.5, 1) == mpi(0.5, 2.0)
|
||||
assert mpi(0.5, 4) / mpi(0.5, 1) == mpi(0.5, 8.0)
|
||||
assert mpi(-1, -0.5) / mpi(0.5, 1) == mpi(-2.0, -0.5)
|
||||
assert mpi(-4, -0.5) / mpi(0.5, 1) == mpi(-8.0, -0.5)
|
||||
assert mpi(-1, 2) / mpi(-2, -0.5) == mpi(-4.0, 2.0)
|
||||
assert mpi(0, 1) / mpi(-2, -0.5) == mpi(-2.0, 0.0)
|
||||
assert mpi(-1, 0) / mpi(-2, -0.5) == mpi(0.0, 2.0)
|
||||
assert mpi(-0.5, -0.25) / mpi(-2, -0.5) == mpi(0.125, 1.0)
|
||||
assert mpi(0.5, 1) / mpi(-2, -0.5) == mpi(-2.0, -0.25)
|
||||
assert mpi(0.5, 4) / mpi(-2, -0.5) == mpi(-8.0, -0.25)
|
||||
assert mpi(-1, -0.5) / mpi(-2, -0.5) == mpi(0.25, 2.0)
|
||||
assert mpi(-4, -0.5) / mpi(-2, -0.5) == mpi(0.25, 8.0)
|
||||
# Should be undefined?
|
||||
assert mpi(0, 0) / mpi(0, 0) == mpi(-inf, inf)
|
||||
assert mpi(0, 0) / mpi(0, 1) == mpi(-inf, inf)
|
||||
|
||||
def test_interval_cos_sin():
|
||||
mp.dps = 15
|
||||
# Around 0
|
||||
assert cos(mpi(0)) == 1
|
||||
assert sin(mpi(0)) == 0
|
||||
assert cos(mpi(0,1)) == mpi(0.54030230586813965399, 1.0)
|
||||
assert sin(mpi(0,1)) == mpi(0, 0.8414709848078966159)
|
||||
assert cos(mpi(1,2)) == mpi(-0.4161468365471424069, 0.54030230586813976501)
|
||||
assert sin(mpi(1,2)) == mpi(0.84147098480789650488, 1.0)
|
||||
assert sin(mpi(1,2.5)) == mpi(0.59847214410395643824, 1.0)
|
||||
assert cos(mpi(-1, 1)) == mpi(0.54030230586813965399, 1.0)
|
||||
assert cos(mpi(-1, 0.5)) == mpi(0.54030230586813965399, 1.0)
|
||||
assert cos(mpi(-1, 1.5)) == mpi(0.070737201667702906405, 1.0)
|
||||
assert sin(mpi(-1,1)) == mpi(-0.8414709848078966159, 0.8414709848078966159)
|
||||
assert sin(mpi(-1,0.5)) == mpi(-0.8414709848078966159, 0.47942553860420300538)
|
||||
assert sin(mpi(-1,1e-100)) == mpi(-0.8414709848078966159, 1.00000000000000002e-100)
|
||||
assert sin(mpi(-2e-100,1e-100)) == mpi(-2.00000000000000004e-100, 1.00000000000000002e-100)
|
||||
# Same interval
|
||||
assert cos(mpi(2, 2.5)) == mpi(-0.80114361554693380718, -0.41614683654714235139)
|
||||
assert cos(mpi(3.5, 4)) == mpi(-0.93645668729079634129, -0.65364362086361182946)
|
||||
assert cos(mpi(5, 5.5)) == mpi(0.28366218546322624627, 0.70866977429126010168)
|
||||
assert sin(mpi(2, 2.5)) == mpi(0.59847214410395654927, 0.90929742682568170942)
|
||||
assert sin(mpi(3.5, 4)) == mpi(-0.75680249530792831347, -0.35078322768961983646)
|
||||
assert sin(mpi(5, 5.5)) == mpi(-0.95892427466313856499, -0.70554032557039181306)
|
||||
# Higher roots
|
||||
mp.dps = 55
|
||||
w = 4*10**50 + mpf(0.5)
|
||||
for p in [15, 40, 80]:
|
||||
mp.dps = p
|
||||
assert 0 in sin(4*mpi(pi))
|
||||
assert 0 in sin(4*10**50*mpi(pi))
|
||||
assert 0 in cos((4+0.5)*mpi(pi))
|
||||
assert 0 in cos(w*mpi(pi))
|
||||
assert 1 in cos(4*mpi(pi))
|
||||
assert 1 in cos(4*10**50*mpi(pi))
|
||||
mp.dps = 15
|
||||
assert cos(mpi(2,inf)) == mpi(-1,1)
|
||||
assert sin(mpi(2,inf)) == mpi(-1,1)
|
||||
assert cos(mpi(-inf,2)) == mpi(-1,1)
|
||||
assert sin(mpi(-inf,2)) == mpi(-1,1)
|
||||
u = tan(mpi(0.5,1))
|
||||
assert u.a.ae(tan(0.5))
|
||||
assert u.b.ae(tan(1))
|
||||
v = cot(mpi(0.5,1))
|
||||
assert v.a.ae(cot(1))
|
||||
assert v.b.ae(cot(0.5))
|
||||
|
||||
def test_mpi_to_str():
|
||||
mp.dps = 30
|
||||
x = mpi(1, 2)
|
||||
# FIXME: error_dps should not be necessary
|
||||
assert mpi_to_str(x, mode='plusminus', error_dps=6) == '1.5 +- 0.5'
|
||||
assert mpi_to_str(x, mode='plusminus', use_spaces=False, error_dps=6
|
||||
) == '1.5+-0.5'
|
||||
assert mpi_to_str(x, mode='percent') == '1.5 (33.33%)'
|
||||
assert mpi_to_str(x, mode='brackets', use_spaces=False) == '[1.0,2.0]'
|
||||
assert mpi_to_str(x, mode='brackets' , brackets=('<', '>')) == '<1.0, 2.0>'
|
||||
x = mpi('5.2582327113062393041', '5.2582327113062749951')
|
||||
assert (mpi_to_str(x, mode='diff') ==
|
||||
'5.2582327113062[393041, 749951]')
|
||||
assert (mpi_to_str(cos(mpi(1)), mode='diff', use_spaces=False) ==
|
||||
'0.54030230586813971740093660744[2955,3053]')
|
||||
assert (mpi_to_str(mpi('1e123', '1e129'), mode='diff') ==
|
||||
'[1.0e+123, 1.0e+129]')
|
||||
assert (mpi_to_str(exp(mpi('5000.1')), mode='diff') ==
|
||||
'3.2797365856787867069110487[0926, 1191]e+2171')
|
||||
|
||||
def test_mpi_from_str():
|
||||
assert mpi_from_str('1.5 +- 0.5') == mpi(mpf('1.0'), mpf('2.0'))
|
||||
assert (mpi_from_str('1.5 (33.33333333333333333333333333333%)') ==
|
||||
mpi(mpf(1), mpf(2)))
|
||||
assert mpi_from_str('[1, 2]') == mpi(1, 2)
|
||||
assert mpi_from_str('1[2, 3]') == mpi(12, 13)
|
||||
assert mpi_from_str('1.[23,46]e-8') == mpi('1.23e-8', '1.46e-8')
|
||||
assert mpi_from_str('12[3.4,5.9]e4') == mpi('123.4e+4', '125.9e4')
|
||||
|
|
@ -1,243 +0,0 @@
|
|||
# TODO: don't use round
|
||||
|
||||
from __future__ import division
|
||||
|
||||
from mpmath import *
|
||||
|
||||
# XXX: these shouldn't be visible(?)
|
||||
LU_decomp = mp.LU_decomp
|
||||
L_solve = mp.L_solve
|
||||
U_solve = mp.U_solve
|
||||
householder = mp.householder
|
||||
improve_solution = mp.improve_solution
|
||||
|
||||
A1 = matrix([[3, 1, 6],
|
||||
[2, 1, 3],
|
||||
[1, 1, 1]])
|
||||
b1 = [2, 7, 4]
|
||||
|
||||
A2 = matrix([[ 2, -1, -1, 2],
|
||||
[ 6, -2, 3, -1],
|
||||
[-4, 2, 3, -2],
|
||||
[ 2, 0, 4, -3]])
|
||||
b2 = [3, -3, -2, -1]
|
||||
|
||||
A3 = matrix([[ 1, 0, -1, -1, 0],
|
||||
[ 0, 1, 1, 0, -1],
|
||||
[ 4, -5, 2, 0, 0],
|
||||
[ 0, 0, -2, 9,-12],
|
||||
[ 0, 5, 0, 0, 12]])
|
||||
b3 = [0, 0, 0, 0, 50]
|
||||
|
||||
A4 = matrix([[10.235, -4.56, 0., -0.035, 5.67],
|
||||
[-2.463, 1.27, 3.97, -8.63, 1.08],
|
||||
[-6.58, 0.86, -0.257, 9.32, -43.6 ],
|
||||
[ 9.83, 7.39, -17.25, 0.036, 24.86],
|
||||
[-9.31, 34.9, 78.56, 1.07, 65.8 ]])
|
||||
b4 = [8.95, 20.54, 7.42, 5.60, 58.43]
|
||||
|
||||
A5 = matrix([[ 1, 2, -4],
|
||||
[-2, -3, 5],
|
||||
[ 3, 5, -8]])
|
||||
|
||||
A6 = matrix([[ 1.377360, 2.481400, 5.359190],
|
||||
[ 2.679280, -1.229560, 25.560210],
|
||||
[-1.225280+1.e6, 9.910180, -35.049900-1.e6]])
|
||||
b6 = [23.500000, -15.760000, 2.340000]
|
||||
|
||||
A7 = matrix([[1, -0.5],
|
||||
[2, 1],
|
||||
[-2, 6]])
|
||||
b7 = [3, 2, -4]
|
||||
|
||||
A8 = matrix([[1, 2, 3],
|
||||
[-1, 0, 1],
|
||||
[-1, -2, -1],
|
||||
[1, 0, -1]])
|
||||
b8 = [1, 2, 3, 4]
|
||||
|
||||
A9 = matrix([[ 4, 2, -2],
|
||||
[ 2, 5, -4],
|
||||
[-2, -4, 5.5]])
|
||||
b9 = [10, 16, -15.5]
|
||||
|
||||
A10 = matrix([[1.0 + 1.0j, 2.0, 2.0],
|
||||
[4.0, 5.0, 6.0],
|
||||
[7.0, 8.0, 9.0]])
|
||||
b10 = [1.0, 1.0 + 1.0j, 1.0]
|
||||
|
||||
|
||||
def test_LU_decomp():
|
||||
A = A3.copy()
|
||||
b = b3
|
||||
A, p = LU_decomp(A)
|
||||
y = L_solve(A, b, p)
|
||||
x = U_solve(A, y)
|
||||
assert p == [2, 1, 2, 3]
|
||||
assert [round(i, 14) for i in x] == [3.78953107960742, 2.9989094874591098,
|
||||
-0.081788440567070006, 3.8713195201744801, 2.9171210468920399]
|
||||
A = A4.copy()
|
||||
b = b4
|
||||
A, p = LU_decomp(A)
|
||||
y = L_solve(A, b, p)
|
||||
x = U_solve(A, y)
|
||||
assert p == [0, 3, 4, 3]
|
||||
assert [round(i, 14) for i in x] == [2.6383625899619201, 2.6643834462368399,
|
||||
0.79208015947958998, -2.5088376454101899, -1.0567657691375001]
|
||||
A = randmatrix(3)
|
||||
bak = A.copy()
|
||||
LU_decomp(A, overwrite=1)
|
||||
assert A != bak
|
||||
|
||||
def test_inverse():
|
||||
for A in [A1, A2, A5]:
|
||||
inv = inverse(A)
|
||||
assert mnorm(A*inv - eye(A.rows), 1) < 1.e-14
|
||||
|
||||
def test_householder():
|
||||
mp.dps = 15
|
||||
A, b = A8, b8
|
||||
H, p, x, r = householder(extend(A, b))
|
||||
assert H == matrix(
|
||||
[[mpf('3.0'), mpf('-2.0'), mpf('-1.0'), 0],
|
||||
[-1.0,mpf('3.333333333333333'),mpf('-2.9999999999999991'),mpf('2.0')],
|
||||
[-1.0, mpf('-0.66666666666666674'),mpf('2.8142135623730948'),
|
||||
mpf('-2.8284271247461898')],
|
||||
[1.0, mpf('-1.3333333333333333'),mpf('-0.20000000000000018'),
|
||||
mpf('4.2426406871192857')]])
|
||||
assert p == [-2, -2, mpf('-1.4142135623730949')]
|
||||
assert round(norm(r, 2), 10) == 4.2426406870999998
|
||||
|
||||
y = [102.102, 58.344, 36.463, 24.310, 17.017, 12.376, 9.282, 7.140, 5.610,
|
||||
4.488, 3.6465, 3.003]
|
||||
|
||||
def coeff(n):
|
||||
# similiar to Hilbert matrix
|
||||
A = []
|
||||
for i in xrange(1, 13):
|
||||
A.append([1. / (i + j - 1) for j in xrange(1, n + 1)])
|
||||
return matrix(A)
|
||||
|
||||
residuals = []
|
||||
refres = []
|
||||
for n in xrange(2, 7):
|
||||
A = coeff(n)
|
||||
H, p, x, r = householder(extend(A, y))
|
||||
x = matrix(x)
|
||||
y = matrix(y)
|
||||
residuals.append(norm(r, 2))
|
||||
refres.append(norm(residual(A, x, y), 2))
|
||||
assert [round(res, 10) for res in residuals] == [15.1733888877,
|
||||
0.82378073210000002, 0.302645887, 0.0260109244,
|
||||
0.00058653999999999998]
|
||||
assert norm(matrix(residuals) - matrix(refres), inf) < 1.e-13
|
||||
|
||||
def test_factorization():
|
||||
A = randmatrix(5)
|
||||
P, L, U = lu(A)
|
||||
assert mnorm(P*A - L*U, 1) < 1.e-15
|
||||
|
||||
def test_solve():
|
||||
assert norm(residual(A6, lu_solve(A6, b6), b6), inf) < 1.e-10
|
||||
assert norm(residual(A7, lu_solve(A7, b7), b7), inf) < 1.5
|
||||
assert norm(residual(A8, lu_solve(A8, b8), b8), inf) <= 3 + 1.e-10
|
||||
assert norm(residual(A6, qr_solve(A6, b6)[0], b6), inf) < 1.e-10
|
||||
assert norm(residual(A7, qr_solve(A7, b7)[0], b7), inf) < 1.5
|
||||
assert norm(residual(A8, qr_solve(A8, b8)[0], b8), 2) <= 4.3
|
||||
assert norm(residual(A10, lu_solve(A10, b10), b10), 2) < 1.e-10
|
||||
assert norm(residual(A10, qr_solve(A10, b10)[0], b10), 2) < 1.e-10
|
||||
|
||||
def test_solve_overdet_complex():
|
||||
A = matrix([[1, 2j], [3, 4j], [5, 6]])
|
||||
b = matrix([1 + j, 2, -j])
|
||||
assert norm(residual(A, lu_solve(A, b), b)) < 1.0208
|
||||
|
||||
def test_singular():
|
||||
mp.dps = 15
|
||||
A = [[5.6, 1.2], [7./15, .1]]
|
||||
B = repr(zeros(2))
|
||||
b = [1, 2]
|
||||
def _assert_ZeroDivisionError(statement):
|
||||
try:
|
||||
eval(statement)
|
||||
assert False
|
||||
except (ZeroDivisionError, ValueError):
|
||||
pass
|
||||
for i in ['lu_solve(%s, %s)' % (A, b), 'lu_solve(%s, %s)' % (B, b),
|
||||
'qr_solve(%s, %s)' % (A, b), 'qr_solve(%s, %s)' % (B, b)]:
|
||||
_assert_ZeroDivisionError(i)
|
||||
|
||||
def test_cholesky():
|
||||
assert fp.cholesky(fp.matrix(A9)) == fp.matrix([[2, 0, 0], [1, 2, 0], [-1, -3/2, 3/2]])
|
||||
x = fp.cholesky_solve(A9, b9)
|
||||
assert fp.norm(fp.residual(A9, x, b9), fp.inf) == 0
|
||||
|
||||
def test_det():
|
||||
assert det(A1) == 1
|
||||
assert round(det(A2), 14) == 8
|
||||
assert round(det(A3)) == 1834
|
||||
assert round(det(A4)) == 4443376
|
||||
assert det(A5) == 1
|
||||
assert round(det(A6)) == 78356463
|
||||
assert det(zeros(3)) == 0
|
||||
|
||||
def test_cond():
|
||||
mp.dps = 15
|
||||
A = matrix([[1.2969, 0.8648], [0.2161, 0.1441]])
|
||||
assert cond(A, lambda x: mnorm(x,1)) == mpf('327065209.73817754')
|
||||
assert cond(A, lambda x: mnorm(x,inf)) == mpf('327065209.73817754')
|
||||
assert cond(A, lambda x: mnorm(x,'F')) == mpf('249729266.80008656')
|
||||
|
||||
@extradps(50)
|
||||
def test_precision():
|
||||
A = randmatrix(10, 10)
|
||||
assert mnorm(inverse(inverse(A)) - A, 1) < 1.e-45
|
||||
|
||||
def test_interval_matrix():
|
||||
a = matrix([['0.1','0.3','1.0'],['7.1','5.5','4.8'],['3.2','4.4','5.6']],
|
||||
force_type=mpi)
|
||||
b = matrix(['4','0.6','0.5'], force_type=mpi)
|
||||
c = lu_solve(a, b)
|
||||
assert c[0].delta < 1e-13
|
||||
assert c[1].delta < 1e-13
|
||||
assert c[2].delta < 1e-13
|
||||
assert 5.25823271130625686059275 in c[0]
|
||||
assert -13.155049396267837541163 in c[1]
|
||||
assert 7.42069154774972557628979 in c[2]
|
||||
|
||||
def test_LU_cache():
|
||||
A = randmatrix(3)
|
||||
LU = LU_decomp(A)
|
||||
assert A._LU == LU_decomp(A)
|
||||
A[0,0] = -1000
|
||||
assert A._LU is None
|
||||
|
||||
def test_improve_solution():
|
||||
A = randmatrix(5, min=1e-20, max=1e20)
|
||||
b = randmatrix(5, 1, min=-1000, max=1000)
|
||||
x1 = lu_solve(A, b) + randmatrix(5, 1, min=-1e-5, max=1.e-5)
|
||||
x2 = improve_solution(A, x1, b)
|
||||
assert norm(residual(A, x2, b), 2) < norm(residual(A, x1, b), 2)
|
||||
|
||||
def test_exp_pade():
|
||||
for i in range(3):
|
||||
dps = 15
|
||||
extra = 5
|
||||
mp.dps = dps + extra
|
||||
dm = 0
|
||||
while not dm:
|
||||
m = randmatrix(3)
|
||||
dm = det(m)
|
||||
m = m/dm
|
||||
a = diag([1,2,3])
|
||||
a1 = m**-1 * a * m
|
||||
mp.dps = dps
|
||||
e1 = expm(a1, method='pade')
|
||||
mp.dps = dps + extra
|
||||
e2 = m * a1 * m**-1
|
||||
d = e2 - a
|
||||
#print d
|
||||
mp.dps = dps
|
||||
assert norm(d, inf).ae(0)
|
||||
mp.dps = 15
|
||||
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_matrix_basic():
|
||||
A1 = matrix(3)
|
||||
for i in xrange(3):
|
||||
A1[i,i] = 1
|
||||
assert A1 == eye(3)
|
||||
assert A1 == matrix(A1)
|
||||
A2 = matrix(3, 2)
|
||||
assert not A2._matrix__data
|
||||
A3 = matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
assert list(A3) == range(1, 10)
|
||||
A3[1,1] = 0
|
||||
assert not (1, 1) in A3._matrix__data
|
||||
A4 = matrix([[1, 2, 3], [4, 5, 6]])
|
||||
A5 = matrix([[6, -1], [3, 2], [0, -3]])
|
||||
assert A4 * A5 == matrix([[12, -6], [39, -12]])
|
||||
assert A1 * A3 == A3 * A1 == A3
|
||||
try:
|
||||
A2 * A2
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
l = [[10, 20, 30], [40, 0, 60], [70, 80, 90]]
|
||||
A6 = matrix(l)
|
||||
assert A6.tolist() == l
|
||||
assert A6 == eval(repr(A6))
|
||||
A6 = matrix(A6, force_type=float)
|
||||
assert A6 == eval(repr(A6))
|
||||
assert A6*1j == eval(repr(A6*1j))
|
||||
assert A3 * 10 == 10 * A3 == A6
|
||||
assert A2.rows == 3
|
||||
assert A2.cols == 2
|
||||
A3.rows = 2
|
||||
A3.cols = 2
|
||||
assert len(A3._matrix__data) == 3
|
||||
assert A4 + A4 == 2*A4
|
||||
try:
|
||||
A4 + A2
|
||||
except ValueError:
|
||||
pass
|
||||
assert sum(A1 - A1) == 0
|
||||
A7 = matrix([[1, 2], [3, 4], [5, 6], [7, 8]])
|
||||
x = matrix([10, -10])
|
||||
assert A7*x == matrix([-10, -10, -10, -10])
|
||||
A8 = ones(5)
|
||||
assert sum((A8 + 1) - (2 - zeros(5))) == 0
|
||||
assert (1 + ones(4)) / 2 - 1 == zeros(4)
|
||||
assert eye(3)**10 == eye(3)
|
||||
try:
|
||||
A7**2
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
A9 = randmatrix(3)
|
||||
A10 = matrix(A9)
|
||||
A9[0,0] = -100
|
||||
assert A9 != A10
|
||||
A11 = matrix(randmatrix(2, 3), force_type=mpi)
|
||||
for a in A11:
|
||||
assert isinstance(a, mpi)
|
||||
assert nstr(A9)
|
||||
|
||||
def test_matrix_power():
|
||||
A = matrix([[1, 2], [3, 4]])
|
||||
assert A**2 == A*A
|
||||
assert A**3 == A*A*A
|
||||
assert A**-1 == inverse(A)
|
||||
assert A**-2 == inverse(A*A)
|
||||
|
||||
def test_matrix_transform():
|
||||
A = matrix([[1, 2], [3, 4], [5, 6]])
|
||||
assert A.T == A.transpose() == matrix([[1, 3, 5], [2, 4, 6]])
|
||||
swap_row(A, 1, 2)
|
||||
assert A == matrix([[1, 2], [5, 6], [3, 4]])
|
||||
l = [1, 2]
|
||||
swap_row(l, 0, 1)
|
||||
assert l == [2, 1]
|
||||
assert extend(eye(3), [1,2,3]) == matrix([[1,0,0,1],[0,1,0,2],[0,0,1,3]])
|
||||
|
||||
def test_matrix_conjugate():
|
||||
A = matrix([[1 + j, 0], [2, j]])
|
||||
assert A.conjugate() == matrix([[mpc(1, -1), 0], [2, mpc(0, -1)]])
|
||||
assert A.transpose_conj() == A.H == matrix([[mpc(1, -1), 2],
|
||||
[0, mpc(0, -1)]])
|
||||
|
||||
def test_matrix_creation():
|
||||
assert diag([1, 2, 3]) == matrix([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
|
||||
A1 = ones(2, 3)
|
||||
assert A1.rows == 2 and A1.cols == 3
|
||||
for a in A1:
|
||||
assert a == 1
|
||||
A2 = zeros(3, 2)
|
||||
assert A2.rows == 3 and A2.cols == 2
|
||||
for a in A2:
|
||||
assert a == 0
|
||||
assert randmatrix(10) != randmatrix(10)
|
||||
one = mpf(1)
|
||||
assert hilbert(3) == matrix([[one, one/2, one/3],
|
||||
[one/2, one/3, one/4],
|
||||
[one/3, one/4, one/5]])
|
||||
|
||||
def test_norms():
|
||||
# matrix norms
|
||||
A = matrix([[1, -2], [-3, -1], [2, 1]])
|
||||
assert mnorm(A,1) == 6
|
||||
assert mnorm(A,inf) == 4
|
||||
assert mnorm(A,'F') == sqrt(20)
|
||||
# vector norms
|
||||
assert norm(-3) == 3
|
||||
x = [1, -2, 7, -12]
|
||||
assert norm(x, 1) == 22
|
||||
assert round(norm(x, 2), 10) == 14.0712472795
|
||||
assert round(norm(x, 10), 10) == 12.0054633727
|
||||
assert norm(x, inf) == 12
|
||||
|
||||
def test_vector():
|
||||
x = matrix([0, 1, 2, 3, 4])
|
||||
assert x == matrix([[0], [1], [2], [3], [4]])
|
||||
assert x[3] == 3
|
||||
assert len(x._matrix__data) == 4
|
||||
assert list(x) == range(5)
|
||||
x[0] = -10
|
||||
x[4] = 0
|
||||
assert x[0] == -10
|
||||
assert len(x) == len(x.T) == 5
|
||||
assert x.T*x == matrix([[114]])
|
||||
|
||||
def test_matrix_copy():
|
||||
A = ones(6)
|
||||
B = A.copy()
|
||||
assert A == B
|
||||
B[0,0] = 0
|
||||
assert A != B
|
||||
|
||||
def test_matrix_numpy():
|
||||
try:
|
||||
import numpy
|
||||
except ImportError:
|
||||
return
|
||||
l = [[1, 2], [3, 4], [5, 6]]
|
||||
a = numpy.matrix(l)
|
||||
assert matrix(l) == matrix(a)
|
||||
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
from mpmath.libmp import *
|
||||
from mpmath import *
|
||||
import random
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Low-level tests
|
||||
#
|
||||
|
||||
# Advanced rounding test
|
||||
def test_add_rounding():
|
||||
mp.dps = 15
|
||||
a = from_float(1e-50)
|
||||
assert mpf_sub(mpf_add(fone, a, 53, round_up), fone, 53, round_up) == from_float(2.2204460492503131e-16)
|
||||
assert mpf_sub(fone, a, 53, round_up) == fone
|
||||
assert mpf_sub(fone, mpf_sub(fone, a, 53, round_down), 53, round_down) == from_float(1.1102230246251565e-16)
|
||||
assert mpf_add(fone, a, 53, round_down) == fone
|
||||
|
||||
def test_almost_equal():
|
||||
assert mpf(1.2).ae(mpf(1.20000001), 1e-7)
|
||||
assert not mpf(1.2).ae(mpf(1.20000001), 1e-9)
|
||||
assert not mpf(-0.7818314824680298).ae(mpf(-0.774695868667929))
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Test basic arithmetic
|
||||
#
|
||||
|
||||
# Test that integer arithmetic is exact
|
||||
def test_aintegers():
|
||||
# XXX: re-fix this so that all operations are tested with all rounding modes
|
||||
random.seed(0)
|
||||
for prec in [6, 10, 25, 40, 100, 250, 725]:
|
||||
for rounding in ['d', 'u', 'f', 'c', 'n']:
|
||||
mp.dps = prec
|
||||
M = 10**(prec-2)
|
||||
M2 = 10**(prec//2-2)
|
||||
for i in range(10):
|
||||
a = random.randint(-M, M)
|
||||
b = random.randint(-M, M)
|
||||
assert mpf(a, rounding=rounding) == a
|
||||
assert int(mpf(a, rounding=rounding)) == a
|
||||
assert int(mpf(str(a), rounding=rounding)) == a
|
||||
assert mpf(a) + mpf(b) == a + b
|
||||
assert mpf(a) - mpf(b) == a - b
|
||||
assert -mpf(a) == -a
|
||||
a = random.randint(-M2, M2)
|
||||
b = random.randint(-M2, M2)
|
||||
assert mpf(a) * mpf(b) == a*b
|
||||
assert mpf_mul(from_int(a), from_int(b), mp.prec, rounding) == from_int(a*b)
|
||||
mp.dps = 15
|
||||
|
||||
def test_odd_int_bug():
|
||||
assert to_int(from_int(3), round_nearest) == 3
|
||||
|
||||
def test_str_1000_digits():
|
||||
mp.dps = 1001
|
||||
# last digit may be wrong
|
||||
assert str(mpf(2)**0.5)[-10:-1] == '9518488472'[:9]
|
||||
assert str(pi)[-10:-1] == '2164201989'[:9]
|
||||
mp.dps = 15
|
||||
|
||||
def test_str_10000_digits():
|
||||
mp.dps = 10001
|
||||
# last digit may be wrong
|
||||
assert str(mpf(2)**0.5)[-10:-1] == '5873258351'[:9]
|
||||
assert str(pi)[-10:-1] == '5256375678'[:9]
|
||||
mp.dps = 15
|
||||
|
||||
def test_monitor():
|
||||
f = lambda x: x**2
|
||||
a = []
|
||||
b = []
|
||||
g = monitor(f, a.append, b.append)
|
||||
assert g(3) == 9
|
||||
assert g(4) == 16
|
||||
assert a[0] == ((3,), {})
|
||||
assert b[0] == 9
|
||||
|
||||
def test_nint_distance():
|
||||
nint_distance(mpf(-3)) == (-3, -inf)
|
||||
nint_distance(mpc(-3)) == (-3, -inf)
|
||||
nint_distance(mpf(-3.1)) == (-3, -3)
|
||||
nint_distance(mpf(-3.01)) == (-3, -6)
|
||||
nint_distance(mpf(-3.001)) == (-3, -9)
|
||||
nint_distance(mpf(-3.0001)) == (-3, -13)
|
||||
nint_distance(mpf(-2.9)) == (-3, -3)
|
||||
nint_distance(mpf(-2.99)) == (-3, -6)
|
||||
nint_distance(mpf(-2.999)) == (-3, -9)
|
||||
nint_distance(mpf(-2.9999)) == (-3, -13)
|
||||
nint_distance(mpc(-3+0.1j)) == (-3, -3)
|
||||
nint_distance(mpc(-3+0.01j)) == (-3, -6)
|
||||
nint_distance(mpc(-3.1+0.1j)) == (-3, -3)
|
||||
nint_distance(mpc(-3.01+0.01j)) == (-3, -6)
|
||||
nint_distance(mpc(-3.001+0.001j)) == (-3, -9)
|
||||
nint_distance(mpf(0)) == (0, -inf)
|
||||
nint_distance(mpf(0.01)) == (0, -6)
|
||||
nint_distance(mpf('1e-100')) == (0, -332)
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
#from mpmath.calculus import ODE_step_euler, ODE_step_rk4, odeint, arange
|
||||
from mpmath import odefun, cos, sin, mpf, sinc, mp
|
||||
|
||||
'''
|
||||
solvers = [ODE_step_euler, ODE_step_rk4]
|
||||
|
||||
def test_ode1():
|
||||
"""
|
||||
Let's solve:
|
||||
|
||||
x'' + w**2 * x = 0
|
||||
|
||||
i.e. x1 = x, x2 = x1':
|
||||
|
||||
x1' = x2
|
||||
x2' = -x1
|
||||
"""
|
||||
def derivs((x1, x2), t):
|
||||
return x2, -x1
|
||||
|
||||
for solver in solvers:
|
||||
t = arange(0, 3.1415926, 0.005)
|
||||
sol = odeint(derivs, (0., 1.), t, solver)
|
||||
x1 = [a[0] for a in sol]
|
||||
x2 = [a[1] for a in sol]
|
||||
# the result is x1 = sin(t), x2 = cos(t)
|
||||
# let's just check the end points for t = pi
|
||||
assert abs(x1[-1]) < 1e-2
|
||||
assert abs(x2[-1] - (-1)) < 1e-2
|
||||
|
||||
def test_ode2():
|
||||
"""
|
||||
Let's solve:
|
||||
|
||||
x' - x = 0
|
||||
|
||||
i.e. x = exp(x)
|
||||
|
||||
"""
|
||||
def derivs((x), t):
|
||||
return x
|
||||
|
||||
for solver in solvers:
|
||||
t = arange(0, 1, 1e-3)
|
||||
sol = odeint(derivs, (1.,), t, solver)
|
||||
x = [a[0] for a in sol]
|
||||
# the result is x = exp(t)
|
||||
# let's just check the end point for t = 1, i.e. x = e
|
||||
assert abs(x[-1] - 2.718281828) < 1e-2
|
||||
'''
|
||||
|
||||
def test_odefun_rational():
|
||||
mp.dps = 15
|
||||
# A rational function
|
||||
f = lambda t: 1/(1+mpf(t)**2)
|
||||
g = odefun(lambda x, y: [-2*x*y[0]**2], 0, [f(0)])
|
||||
assert f(2).ae(g(2)[0])
|
||||
|
||||
def test_odefun_sinc_large():
|
||||
mp.dps = 15
|
||||
# Sinc function; test for large x
|
||||
f = sinc
|
||||
g = odefun(lambda x, y: [(cos(x)-y[0])/x], 1, [f(1)], tol=0.01, degree=5)
|
||||
assert abs(f(100) - g(100)[0])/f(100) < 0.01
|
||||
|
||||
def test_odefun_harmonic():
|
||||
mp.dps = 15
|
||||
# Harmonic oscillator
|
||||
f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
|
||||
for x in [0, 1, 2.5, 8, 3.7]: # we go back to 3.7 to check caching
|
||||
c, s = f(x)
|
||||
assert c.ae(cos(x))
|
||||
assert s.ae(sin(x))
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
import pickle
|
||||
|
||||
from mpmath import *
|
||||
|
||||
def pickler(obj):
|
||||
fn = tempfile.mktemp()
|
||||
|
||||
f = open(fn, 'wb')
|
||||
pickle.dump(obj, f)
|
||||
f.close()
|
||||
|
||||
f = open(fn, 'rb')
|
||||
obj2 = pickle.load(f)
|
||||
f.close()
|
||||
os.remove(fn)
|
||||
|
||||
return obj2
|
||||
|
||||
def test_pickle():
|
||||
|
||||
obj = mpf('0.5')
|
||||
assert obj == pickler(obj)
|
||||
|
||||
obj = mpc('0.5','0.2')
|
||||
assert obj == pickler(obj)
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
from mpmath import *
|
||||
from mpmath.libmp import *
|
||||
|
||||
import random
|
||||
|
||||
def test_fractional_pow():
|
||||
assert mpf(16) ** 2.5 == 1024
|
||||
assert mpf(64) ** 0.5 == 8
|
||||
assert mpf(64) ** -0.5 == 0.125
|
||||
assert mpf(16) ** -2.5 == 0.0009765625
|
||||
assert (mpf(10) ** 0.5).ae(3.1622776601683791)
|
||||
assert (mpf(10) ** 2.5).ae(316.2277660168379)
|
||||
assert (mpf(10) ** -0.5).ae(0.31622776601683794)
|
||||
assert (mpf(10) ** -2.5).ae(0.0031622776601683794)
|
||||
assert (mpf(10) ** 0.3).ae(1.9952623149688795)
|
||||
assert (mpf(10) ** -0.3).ae(0.50118723362727224)
|
||||
|
||||
def test_pow_integer_direction():
|
||||
"""
|
||||
Test that inexact integer powers are rounded in the right
|
||||
direction.
|
||||
"""
|
||||
random.seed(1234)
|
||||
for prec in [10, 53, 200]:
|
||||
for i in range(50):
|
||||
a = random.randint(1<<(prec-1), 1<<prec)
|
||||
b = random.randint(2, 100)
|
||||
ab = a**b
|
||||
# note: could actually be exact, but that's very unlikely!
|
||||
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_down)) < ab
|
||||
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_up)) > ab
|
||||
|
||||
|
||||
def test_pow_epsilon_rounding():
|
||||
"""
|
||||
Stress test directed rounding for powers with integer exponents.
|
||||
Basically, we look at the following cases:
|
||||
|
||||
>>> 1.0001 ** -5
|
||||
0.99950014996500702
|
||||
>>> 0.9999 ** -5
|
||||
1.000500150035007
|
||||
>>> (-1.0001) ** -5
|
||||
-0.99950014996500702
|
||||
>>> (-0.9999) ** -5
|
||||
-1.000500150035007
|
||||
|
||||
>>> 1.0001 ** -6
|
||||
0.99940020994401269
|
||||
>>> 0.9999 ** -6
|
||||
1.0006002100560125
|
||||
>>> (-1.0001) ** -6
|
||||
0.99940020994401269
|
||||
>>> (-0.9999) ** -6
|
||||
1.0006002100560125
|
||||
|
||||
etc.
|
||||
|
||||
We run the tests with values a very small epsilon away from 1:
|
||||
small enough that the result is indistinguishable from 1 when
|
||||
rounded to nearest at the output precision. We check that the
|
||||
result is not erroneously rounded to 1 in cases where the
|
||||
rounding should be done strictly away from 1.
|
||||
"""
|
||||
|
||||
def powr(x, n, r):
|
||||
return make_mpf(mpf_pow_int(x._mpf_, n, mp.prec, r))
|
||||
|
||||
for (inprec, outprec) in [(100, 20), (5000, 3000)]:
|
||||
|
||||
mp.prec = inprec
|
||||
|
||||
pos10001 = mpf(1) + mpf(2)**(-inprec+5)
|
||||
pos09999 = mpf(1) - mpf(2)**(-inprec+5)
|
||||
neg10001 = -pos10001
|
||||
neg09999 = -pos09999
|
||||
|
||||
mp.prec = outprec
|
||||
r = round_up
|
||||
assert powr(pos10001, 5, r) > 1
|
||||
assert powr(pos09999, 5, r) == 1
|
||||
assert powr(neg10001, 5, r) < -1
|
||||
assert powr(neg09999, 5, r) == -1
|
||||
assert powr(pos10001, 6, r) > 1
|
||||
assert powr(pos09999, 6, r) == 1
|
||||
assert powr(neg10001, 6, r) > 1
|
||||
assert powr(neg09999, 6, r) == 1
|
||||
|
||||
assert powr(pos10001, -5, r) == 1
|
||||
assert powr(pos09999, -5, r) > 1
|
||||
assert powr(neg10001, -5, r) == -1
|
||||
assert powr(neg09999, -5, r) < -1
|
||||
assert powr(pos10001, -6, r) == 1
|
||||
assert powr(pos09999, -6, r) > 1
|
||||
assert powr(neg10001, -6, r) == 1
|
||||
assert powr(neg09999, -6, r) > 1
|
||||
|
||||
r = round_down
|
||||
assert powr(pos10001, 5, r) == 1
|
||||
assert powr(pos09999, 5, r) < 1
|
||||
assert powr(neg10001, 5, r) == -1
|
||||
assert powr(neg09999, 5, r) > -1
|
||||
assert powr(pos10001, 6, r) == 1
|
||||
assert powr(pos09999, 6, r) < 1
|
||||
assert powr(neg10001, 6, r) == 1
|
||||
assert powr(neg09999, 6, r) < 1
|
||||
|
||||
assert powr(pos10001, -5, r) < 1
|
||||
assert powr(pos09999, -5, r) == 1
|
||||
assert powr(neg10001, -5, r) > -1
|
||||
assert powr(neg09999, -5, r) == -1
|
||||
assert powr(pos10001, -6, r) < 1
|
||||
assert powr(pos09999, -6, r) == 1
|
||||
assert powr(neg10001, -6, r) < 1
|
||||
assert powr(neg09999, -6, r) == 1
|
||||
|
||||
r = round_ceiling
|
||||
assert powr(pos10001, 5, r) > 1
|
||||
assert powr(pos09999, 5, r) == 1
|
||||
assert powr(neg10001, 5, r) == -1
|
||||
assert powr(neg09999, 5, r) > -1
|
||||
assert powr(pos10001, 6, r) > 1
|
||||
assert powr(pos09999, 6, r) == 1
|
||||
assert powr(neg10001, 6, r) > 1
|
||||
assert powr(neg09999, 6, r) == 1
|
||||
|
||||
assert powr(pos10001, -5, r) == 1
|
||||
assert powr(pos09999, -5, r) > 1
|
||||
assert powr(neg10001, -5, r) > -1
|
||||
assert powr(neg09999, -5, r) == -1
|
||||
assert powr(pos10001, -6, r) == 1
|
||||
assert powr(pos09999, -6, r) > 1
|
||||
assert powr(neg10001, -6, r) == 1
|
||||
assert powr(neg09999, -6, r) > 1
|
||||
|
||||
r = round_floor
|
||||
assert powr(pos10001, 5, r) == 1
|
||||
assert powr(pos09999, 5, r) < 1
|
||||
assert powr(neg10001, 5, r) < -1
|
||||
assert powr(neg09999, 5, r) == -1
|
||||
assert powr(pos10001, 6, r) == 1
|
||||
assert powr(pos09999, 6, r) < 1
|
||||
assert powr(neg10001, 6, r) == 1
|
||||
assert powr(neg09999, 6, r) < 1
|
||||
|
||||
assert powr(pos10001, -5, r) < 1
|
||||
assert powr(pos09999, -5, r) == 1
|
||||
assert powr(neg10001, -5, r) == -1
|
||||
assert powr(neg09999, -5, r) < -1
|
||||
assert powr(pos10001, -6, r) < 1
|
||||
assert powr(pos09999, -6, r) == 1
|
||||
assert powr(neg10001, -6, r) < 1
|
||||
assert powr(neg09999, -6, r) == 1
|
||||
|
||||
mp.dps = 15
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def ae(a, b):
|
||||
return abs(a-b) < 10**(-mp.dps+5)
|
||||
|
||||
def test_basic_integrals():
|
||||
for prec in [15, 30, 100]:
|
||||
mp.dps = prec
|
||||
assert ae(quadts(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
|
||||
assert ae(quadgl(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
|
||||
assert ae(quadts(sin, [0, pi]), 2)
|
||||
assert ae(quadts(sin, [0, 2*pi]), 0)
|
||||
assert ae(quadts(exp, [-inf, -1]), 1/e)
|
||||
assert ae(quadts(lambda x: exp(-x), [0, inf]), 1)
|
||||
assert ae(quadts(lambda x: exp(-x*x), [-inf, inf]), sqrt(pi))
|
||||
assert ae(quadts(lambda x: 1/(1+x*x), [-1, 1]), pi/2)
|
||||
assert ae(quadts(lambda x: 1/(1+x*x), [-inf, inf]), pi)
|
||||
assert ae(quadts(lambda x: 2*sqrt(1-x*x), [-1, 1]), pi)
|
||||
mp.dps = 15
|
||||
|
||||
def test_quad_symmetry():
|
||||
assert quadts(sin, [-1, 1]) == 0
|
||||
assert quadgl(sin, [-1, 1]) == 0
|
||||
|
||||
def test_quadgl_linear():
|
||||
assert quadgl(lambda x: x, [0, 1], maxdegree=1).ae(0.5)
|
||||
|
||||
def test_complex_integration():
|
||||
assert quadts(lambda x: x, [0, 1+j]).ae(j)
|
||||
|
||||
def test_quadosc():
|
||||
mp.dps = 15
|
||||
assert quadosc(lambda x: sin(x)/x, [0, inf], period=2*pi).ae(pi/2)
|
||||
|
||||
# Double integrals
|
||||
def test_double_trivial():
|
||||
assert ae(quadts(lambda x, y: x, [0, 1], [0, 1]), 0.5)
|
||||
assert ae(quadts(lambda x, y: x, [-1, 1], [-1, 1]), 0.0)
|
||||
|
||||
def test_double_1():
|
||||
assert ae(quadts(lambda x, y: cos(x+y/2), [-pi/2, pi/2], [0, pi]), 4)
|
||||
|
||||
def test_double_2():
|
||||
assert ae(quadts(lambda x, y: (x-1)/((1-x*y)*log(x*y)), [0, 1], [0, 1]), euler)
|
||||
|
||||
def test_double_3():
|
||||
assert ae(quadts(lambda x, y: 1/sqrt(1+x*x+y*y), [-1, 1], [-1, 1]), 4*log(2+sqrt(3))-2*pi/3)
|
||||
|
||||
def test_double_4():
|
||||
assert ae(quadts(lambda x, y: 1/(1-x*x * y*y), [0, 1], [0, 1]), pi**2 / 8)
|
||||
|
||||
def test_double_5():
|
||||
assert ae(quadts(lambda x, y: 1/(1-x*y), [0, 1], [0, 1]), pi**2 / 6)
|
||||
|
||||
def test_double_6():
|
||||
assert ae(quadts(lambda x, y: exp(-(x+y)), [0, inf], [0, inf]), 1)
|
||||
|
||||
# fails
|
||||
def xtest_double_7():
|
||||
assert ae(quadts(lambda x, y: exp(-x*x-y*y), [-inf, inf], [-inf, inf]), pi)
|
||||
|
||||
|
||||
# Test integrals from "Experimentation in Mathematics" by Borwein,
|
||||
# Bailey & Girgensohn
|
||||
def test_expmath_integrals():
|
||||
for prec in [15, 30, 50]:
|
||||
mp.dps = prec
|
||||
assert ae(quadts(lambda x: x/sinh(x), [0, inf]), pi**2 / 4)
|
||||
assert ae(quadts(lambda x: log(x)**2 / (1+x**2), [0, inf]), pi**3 / 8)
|
||||
assert ae(quadts(lambda x: (1+x**2)/(1+x**4), [0, inf]), pi/sqrt(2))
|
||||
assert ae(quadts(lambda x: log(x)/cosh(x)**2, [0, inf]), log(pi)-2*log(2)-euler)
|
||||
assert ae(quadts(lambda x: log(1+x**3)/(1-x+x**2), [0, inf]), 2*pi*log(3)/sqrt(3))
|
||||
assert ae(quadts(lambda x: log(x)**2 / (x**2+x+1), [0, 1]), 8*pi**3 / (81*sqrt(3)))
|
||||
assert ae(quadts(lambda x: log(cos(x))**2, [0, pi/2]), pi/2 * (log(2)**2+pi**2/12))
|
||||
assert ae(quadts(lambda x: x**2 / sin(x)**2, [0, pi/2]), pi*log(2))
|
||||
assert ae(quadts(lambda x: x**2/sqrt(exp(x)-1), [0, inf]), 4*pi*(log(2)**2 + pi**2/12))
|
||||
assert ae(quadts(lambda x: x*exp(-x)*sqrt(1-exp(-2*x)), [0, inf]), pi*(1+2*log(2))/8)
|
||||
mp.dps = 15
|
||||
|
||||
# Do not reach full accuracy
|
||||
def xtest_expmath_fail():
|
||||
assert ae(quadts(lambda x: sqrt(tan(x)), [0, pi/2]), pi*sqrt(2)/2)
|
||||
assert ae(quadts(lambda x: atan(x)/(x*sqrt(1-x**2)), [0, 1]), pi*log(1+sqrt(2))/2)
|
||||
assert ae(quadts(lambda x: log(1+x**2)/x**2, [0, 1]), pi/2-log(2))
|
||||
assert ae(quadts(lambda x: x**2/((1+x**4)*sqrt(1-x**4)), [0, 1]), pi/8)
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
from mpmath import *
|
||||
from mpmath.calculus.optimization import Secant, Muller, Bisection, Illinois, \
|
||||
Pegasus, Anderson, Ridder, ANewton, Newton, MNewton, MDNewton
|
||||
|
||||
def test_findroot():
|
||||
# old tests, assuming secant
|
||||
mp.dps = 15
|
||||
assert findroot(lambda x: 4*x-3, mpf(5)).ae(0.75)
|
||||
assert findroot(sin, mpf(3)).ae(pi)
|
||||
assert findroot(sin, (mpf(3), mpf(3.14))).ae(pi)
|
||||
assert findroot(lambda x: x*x+1, mpc(2+2j)).ae(1j)
|
||||
# test all solvers with 1 starting point
|
||||
f = lambda x: cos(x)
|
||||
for solver in [Newton, Secant, MNewton, Muller, ANewton]:
|
||||
x = findroot(f, 2., solver=solver)
|
||||
assert abs(f(x)) < eps
|
||||
# test all solvers with interval of 2 points
|
||||
for solver in [Secant, Muller, Bisection, Illinois, Pegasus, Anderson,
|
||||
Ridder]:
|
||||
x = findroot(f, (1., 2.), solver=solver)
|
||||
assert abs(f(x)) < eps
|
||||
# test types
|
||||
f = lambda x: (x - 2)**2
|
||||
|
||||
#assert isinstance(findroot(f, 1, force_type=mpf, tol=1e-10), mpf)
|
||||
#assert isinstance(findroot(f, 1., force_type=None, tol=1e-10), float)
|
||||
#assert isinstance(findroot(f, 1, force_type=complex, tol=1e-10), complex)
|
||||
assert isinstance(fp.findroot(f, 1, tol=1e-10), float)
|
||||
assert isinstance(fp.findroot(f, 1+0j, tol=1e-10), complex)
|
||||
|
||||
def test_mnewton():
|
||||
f = lambda x: polyval([1,3,3,1],x)
|
||||
x = findroot(f, -0.9, solver='mnewton')
|
||||
assert abs(f(x)) < eps
|
||||
|
||||
def test_anewton():
|
||||
f = lambda x: (x - 2)**100
|
||||
x = findroot(f, 1., solver=ANewton)
|
||||
assert abs(f(x)) < eps
|
||||
|
||||
def test_muller():
|
||||
f = lambda x: (2 + x)**3 + 2
|
||||
x = findroot(f, 1., solver=Muller)
|
||||
assert abs(f(x)) < eps
|
||||
|
||||
def test_multiplicity():
|
||||
for i in xrange(1, 5):
|
||||
assert multiplicity(lambda x: (x - 1)**i, 1) == i
|
||||
assert multiplicity(lambda x: x**2, 1) == 0
|
||||
|
||||
def test_multidimensional():
|
||||
def f(*x):
|
||||
return [3*x[0]**2-2*x[1]**2-1, x[0]**2-2*x[0]+x[1]**2+2*x[1]-8]
|
||||
assert mnorm(jacobian(f, (1,-2)) - matrix([[6,8],[0,-2]]),1) < 1.e-7
|
||||
for x, error in MDNewton(mp, f, (1,-2), verbose=0,
|
||||
norm=lambda x: norm(x, inf)):
|
||||
pass
|
||||
assert norm(f(*x), 2) < 1e-14
|
||||
# The Chinese mathematician Zhu Shijie was the very first to solve this
|
||||
# nonlinear system 700 years ago
|
||||
f1 = lambda x, y: -x + 2*y
|
||||
f2 = lambda x, y: (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
|
||||
f3 = lambda x, y: sqrt(x**2 + y**2)
|
||||
def f(x, y):
|
||||
f1x = f1(x, y)
|
||||
return (f2(x, y) - f1x, f3(x, y) - f1x)
|
||||
x = findroot(f, (10, 10))
|
||||
assert [int(round(i)) for i in x] == [3, 4]
|
||||
|
||||
def test_trivial():
|
||||
assert findroot(lambda x: 0, 1) == 1
|
||||
assert findroot(lambda x: x, 0) == 0
|
||||
#assert findroot(lambda x, y: x + y, (1, -1)) == (1, -1)
|
||||
|
||||
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_special():
|
||||
assert inf == inf
|
||||
assert inf != -inf
|
||||
assert -inf == -inf
|
||||
assert inf != nan
|
||||
assert nan != nan
|
||||
assert isnan(nan)
|
||||
assert --inf == inf
|
||||
assert abs(inf) == inf
|
||||
assert abs(-inf) == inf
|
||||
assert abs(nan) != abs(nan)
|
||||
|
||||
assert isnan(inf - inf)
|
||||
assert isnan(inf + (-inf))
|
||||
assert isnan(-inf - (-inf))
|
||||
|
||||
assert isnan(inf + nan)
|
||||
assert isnan(-inf + nan)
|
||||
|
||||
assert mpf(2) + inf == inf
|
||||
assert 2 + inf == inf
|
||||
assert mpf(2) - inf == -inf
|
||||
assert 2 - inf == -inf
|
||||
|
||||
assert inf > 3
|
||||
assert 3 < inf
|
||||
assert 3 > -inf
|
||||
assert -inf < 3
|
||||
assert inf > mpf(3)
|
||||
assert mpf(3) < inf
|
||||
assert mpf(3) > -inf
|
||||
assert -inf < mpf(3)
|
||||
|
||||
assert not (nan < 3)
|
||||
assert not (nan > 3)
|
||||
|
||||
assert isnan(inf * 0)
|
||||
assert isnan(-inf * 0)
|
||||
assert inf * 3 == inf
|
||||
assert inf * -3 == -inf
|
||||
assert -inf * 3 == -inf
|
||||
assert -inf * -3 == inf
|
||||
assert inf * inf == inf
|
||||
assert -inf * -inf == inf
|
||||
|
||||
assert isnan(nan / 3)
|
||||
assert inf / -3 == -inf
|
||||
assert inf / 3 == inf
|
||||
assert 3 / inf == 0
|
||||
assert -3 / inf == 0
|
||||
assert 0 / inf == 0
|
||||
assert isnan(inf / inf)
|
||||
assert isnan(inf / -inf)
|
||||
assert isnan(inf / nan)
|
||||
|
||||
assert mpf('inf') == mpf('+inf') == inf
|
||||
assert mpf('-inf') == -inf
|
||||
assert isnan(mpf('nan'))
|
||||
|
||||
assert isinf(inf)
|
||||
assert isinf(-inf)
|
||||
assert not isinf(mpf(0))
|
||||
assert not isinf(nan)
|
||||
|
||||
def test_special_powers():
|
||||
assert inf**3 == inf
|
||||
assert isnan(inf**0)
|
||||
assert inf**-3 == 0
|
||||
assert (-inf)**2 == inf
|
||||
assert (-inf)**3 == -inf
|
||||
assert isnan((-inf)**0)
|
||||
assert (-inf)**-2 == 0
|
||||
assert (-inf)**-3 == 0
|
||||
assert isnan(nan**5)
|
||||
assert isnan(nan**0)
|
||||
|
||||
def test_functions_special():
|
||||
assert exp(inf) == inf
|
||||
assert exp(-inf) == 0
|
||||
assert isnan(exp(nan))
|
||||
assert log(inf) == inf
|
||||
assert isnan(sin(inf))
|
||||
assert isnan(sin(nan))
|
||||
assert atan(inf).ae(pi/2)
|
||||
assert atan(-inf).ae(-pi/2)
|
||||
assert isnan(sqrt(nan))
|
||||
assert sqrt(inf) == inf
|
||||
|
||||
def test_convert_special():
|
||||
float_inf = 1e300 * 1e300
|
||||
float_ninf = -float_inf
|
||||
float_nan = float_inf/float_ninf
|
||||
assert mpf(3) * float_inf == inf
|
||||
assert mpf(3) * float_ninf == -inf
|
||||
assert isnan(mpf(3) * float_nan)
|
||||
assert not (mpf(3) < float_nan)
|
||||
assert not (mpf(3) > float_nan)
|
||||
assert not (mpf(3) <= float_nan)
|
||||
assert not (mpf(3) >= float_nan)
|
||||
assert float(mpf('1e1000')) == float_inf
|
||||
assert float(mpf('-1e1000')) == float_ninf
|
||||
assert float(mpf('1e100000000000000000')) == float_inf
|
||||
assert float(mpf('-1e100000000000000000')) == float_ninf
|
||||
assert float(mpf('1e-100000000000000000')) == 0.0
|
||||
|
||||
def test_div_bug():
|
||||
assert isnan(nan/1)
|
||||
assert isnan(nan/2)
|
||||
assert inf/2 == inf
|
||||
assert (-inf)/2 == -inf
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
from mpmath import nstr, matrix, inf
|
||||
|
||||
def test_nstr():
|
||||
m = matrix([[0.75, 0.190940654, -0.0299195971],
|
||||
[0.190940654, 0.65625, 0.205663228],
|
||||
[-0.0299195971, 0.205663228, 0.64453125e-20]])
|
||||
assert nstr(m, 4, min_fixed=-inf) == \
|
||||
'''[ 0.75 0.1909 -0.02992]
|
||||
[ 0.1909 0.6563 0.2057]
|
||||
[-0.02992 0.2057 0.000000000000000000006445]'''
|
||||
assert nstr(m, 4) == \
|
||||
'''[ 0.75 0.1909 -0.02992]
|
||||
[ 0.1909 0.6563 0.2057]
|
||||
[-0.02992 0.2057 6.445e-21]'''
|
||||
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
from mpmath import *
|
||||
|
||||
def test_sumem():
|
||||
mp.dps = 15
|
||||
assert sumem(lambda k: 1/k**2.5, [50, 100]).ae(0.0012524505324784962)
|
||||
assert sumem(lambda k: k**4 + 3*k + 1, [10, 100]).ae(2050333103)
|
||||
|
||||
def test_nsum():
|
||||
mp.dps = 15
|
||||
assert nsum(lambda x: x**2, [1, 3]) == 14
|
||||
assert nsum(lambda k: 1/factorial(k), [0, inf]).ae(e)
|
||||
assert nsum(lambda k: (-1)**(k+1) / k, [1, inf]).ae(log(2))
|
||||
assert nsum(lambda k: (-1)**(k+1) / k**2, [1, inf]).ae(pi**2 / 12)
|
||||
assert nsum(lambda k: (-1)**k / log(k), [2, inf]).ae(0.9242998972229388)
|
||||
assert nsum(lambda k: 1/k**2, [1, inf]).ae(pi**2 / 6)
|
||||
assert nsum(lambda k: 2**k/fac(k), [0, inf]).ae(exp(2))
|
||||
assert nsum(lambda k: 1/k**2, [4, inf], method='e').ae(0.2838229557371153)
|
||||
|
||||
def test_nprod():
|
||||
mp.dps = 15
|
||||
assert nprod(lambda k: exp(1/k**2), [1,inf], method='r').ae(exp(pi**2/6))
|
||||
assert nprod(lambda x: x**2, [1, 3]) == 36
|
||||
|
||||
def test_fsum():
|
||||
mp.dps = 15
|
||||
assert fsum([]) == 0
|
||||
assert fsum([-4]) == -4
|
||||
assert fsum([2,3]) == 5
|
||||
assert fsum([1e-100,1]) == 1
|
||||
assert fsum([1,1e-100]) == 1
|
||||
assert fsum([1e100,1]) == 1e100
|
||||
assert fsum([1,1e100]) == 1e100
|
||||
assert fsum([1e-100,0]) == 1e-100
|
||||
assert fsum([1e-100,1e100,1e-100]) == 1e100
|
||||
assert fsum([2,1+1j,1]) == 4+1j
|
||||
assert fsum([1,mpi(2,3)]) == mpi(3,4)
|
||||
assert fsum([2,inf,3]) == inf
|
||||
assert fsum([2,-1], absolute=1) == 3
|
||||
assert fsum([2,-1], squared=1) == 5
|
||||
assert fsum([1,1+j], squared=1) == 1+2j
|
||||
assert fsum([1,3+4j], absolute=1) == 6
|
||||
assert fsum([1,2+3j], absolute=1, squared=1) == 14
|
||||
assert isnan(fsum([inf,-inf]))
|
||||
assert fsum([inf,-inf], absolute=1) == inf
|
||||
assert fsum([inf,-inf], squared=1) == inf
|
||||
assert fsum([inf,-inf], absolute=1, squared=1) == inf
|
||||
|
||||
def test_fprod():
|
||||
mp.dps = 15
|
||||
assert fprod([]) == 1
|
||||
assert fprod([2,3]) == 6
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
from mpmath import *
|
||||
from mpmath.libmp import *
|
||||
|
||||
def test_trig_misc_hard():
|
||||
mp.prec = 53
|
||||
# Worst-case input for an IEEE double, from a paper by Kahan
|
||||
x = ldexp(6381956970095103,797)
|
||||
assert cos(x) == mpf('-4.6871659242546277e-19')
|
||||
assert sin(x) == 1
|
||||
|
||||
mp.prec = 150
|
||||
a = mpf(10**50)
|
||||
mp.prec = 53
|
||||
assert sin(a).ae(-0.7896724934293100827)
|
||||
assert cos(a).ae(-0.6135286082336635622)
|
||||
|
||||
# Check relative accuracy close to x = zero
|
||||
assert sin(1e-100) == 1e-100 # when rounding to nearest
|
||||
assert sin(1e-6).ae(9.999999999998333e-007, rel_eps=2e-15, abs_eps=0)
|
||||
assert sin(1e-6j).ae(1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
|
||||
assert sin(-1e-6j).ae(-1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
|
||||
assert cos(1e-100) == 1
|
||||
assert cos(1e-6).ae(0.9999999999995)
|
||||
assert cos(-1e-6j).ae(1.0000000000005)
|
||||
assert tan(1e-100) == 1e-100
|
||||
assert tan(1e-6).ae(1.0000000000003335e-006, rel_eps=2e-15, abs_eps=0)
|
||||
assert tan(1e-6j).ae(9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
|
||||
assert tan(-1e-6j).ae(-9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
|
||||
|
||||
def test_trig_near_zero():
|
||||
mp.dps = 15
|
||||
|
||||
for r in [round_nearest, round_down, round_up, round_floor, round_ceiling]:
|
||||
assert sin(0, rounding=r) == 0
|
||||
assert cos(0, rounding=r) == 1
|
||||
|
||||
a = mpf('1e-100')
|
||||
b = mpf('-1e-100')
|
||||
|
||||
assert sin(a, rounding=round_nearest) == a
|
||||
assert sin(a, rounding=round_down) < a
|
||||
assert sin(a, rounding=round_floor) < a
|
||||
assert sin(a, rounding=round_up) >= a
|
||||
assert sin(a, rounding=round_ceiling) >= a
|
||||
assert sin(b, rounding=round_nearest) == b
|
||||
assert sin(b, rounding=round_down) > b
|
||||
assert sin(b, rounding=round_floor) <= b
|
||||
assert sin(b, rounding=round_up) <= b
|
||||
assert sin(b, rounding=round_ceiling) > b
|
||||
|
||||
assert cos(a, rounding=round_nearest) == 1
|
||||
assert cos(a, rounding=round_down) < 1
|
||||
assert cos(a, rounding=round_floor) < 1
|
||||
assert cos(a, rounding=round_up) == 1
|
||||
assert cos(a, rounding=round_ceiling) == 1
|
||||
assert cos(b, rounding=round_nearest) == 1
|
||||
assert cos(b, rounding=round_down) < 1
|
||||
assert cos(b, rounding=round_floor) < 1
|
||||
assert cos(b, rounding=round_up) == 1
|
||||
assert cos(b, rounding=round_ceiling) == 1
|
||||
|
||||
|
||||
def test_trig_near_n_pi():
|
||||
|
||||
mp.dps = 15
|
||||
a = [n*pi for n in [1, 2, 6, 11, 100, 1001, 10000, 100001]]
|
||||
mp.dps = 135
|
||||
a.append(10**100 * pi)
|
||||
mp.dps = 15
|
||||
|
||||
assert sin(a[0]) == mpf('1.2246467991473531772e-16')
|
||||
assert sin(a[1]) == mpf('-2.4492935982947063545e-16')
|
||||
assert sin(a[2]) == mpf('-7.3478807948841190634e-16')
|
||||
assert sin(a[3]) == mpf('4.8998251578625894243e-15')
|
||||
assert sin(a[4]) == mpf('1.9643867237284719452e-15')
|
||||
assert sin(a[5]) == mpf('-8.8632615209684813458e-15')
|
||||
assert sin(a[6]) == mpf('-4.8568235395684898392e-13')
|
||||
assert sin(a[7]) == mpf('3.9087342299491231029e-11')
|
||||
assert sin(a[8]) == mpf('-1.369235466754566993528e-36')
|
||||
|
||||
r = round_nearest
|
||||
assert cos(a[0], rounding=r) == -1
|
||||
assert cos(a[1], rounding=r) == 1
|
||||
assert cos(a[2], rounding=r) == 1
|
||||
assert cos(a[3], rounding=r) == -1
|
||||
assert cos(a[4], rounding=r) == 1
|
||||
assert cos(a[5], rounding=r) == -1
|
||||
assert cos(a[6], rounding=r) == 1
|
||||
assert cos(a[7], rounding=r) == -1
|
||||
assert cos(a[8], rounding=r) == 1
|
||||
|
||||
r = round_up
|
||||
assert cos(a[0], rounding=r) == -1
|
||||
assert cos(a[1], rounding=r) == 1
|
||||
assert cos(a[2], rounding=r) == 1
|
||||
assert cos(a[3], rounding=r) == -1
|
||||
assert cos(a[4], rounding=r) == 1
|
||||
assert cos(a[5], rounding=r) == -1
|
||||
assert cos(a[6], rounding=r) == 1
|
||||
assert cos(a[7], rounding=r) == -1
|
||||
assert cos(a[8], rounding=r) == 1
|
||||
|
||||
r = round_down
|
||||
assert cos(a[0], rounding=r) > -1
|
||||
assert cos(a[1], rounding=r) < 1
|
||||
assert cos(a[2], rounding=r) < 1
|
||||
assert cos(a[3], rounding=r) > -1
|
||||
assert cos(a[4], rounding=r) < 1
|
||||
assert cos(a[5], rounding=r) > -1
|
||||
assert cos(a[6], rounding=r) < 1
|
||||
assert cos(a[7], rounding=r) > -1
|
||||
assert cos(a[8], rounding=r) < 1
|
||||
|
||||
r = round_floor
|
||||
assert cos(a[0], rounding=r) == -1
|
||||
assert cos(a[1], rounding=r) < 1
|
||||
assert cos(a[2], rounding=r) < 1
|
||||
assert cos(a[3], rounding=r) == -1
|
||||
assert cos(a[4], rounding=r) < 1
|
||||
assert cos(a[5], rounding=r) == -1
|
||||
assert cos(a[6], rounding=r) < 1
|
||||
assert cos(a[7], rounding=r) == -1
|
||||
assert cos(a[8], rounding=r) < 1
|
||||
|
||||
r = round_ceiling
|
||||
assert cos(a[0], rounding=r) > -1
|
||||
assert cos(a[1], rounding=r) == 1
|
||||
assert cos(a[2], rounding=r) == 1
|
||||
assert cos(a[3], rounding=r) > -1
|
||||
assert cos(a[4], rounding=r) == 1
|
||||
assert cos(a[5], rounding=r) > -1
|
||||
assert cos(a[6], rounding=r) == 1
|
||||
assert cos(a[7], rounding=r) > -1
|
||||
assert cos(a[8], rounding=r) == 1
|
||||
|
||||
mp.dps = 15
|
||||
|
||||
if __name__ == '__main__':
|
||||
for f in globals().keys():
|
||||
if f.startswith("test_"):
|
||||
print f
|
||||
globals()[f]()
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
"""
|
||||
Limited tests of the visualization module. Right now it just makes
|
||||
sure that passing custom Axes works.
|
||||
|
||||
"""
|
||||
|
||||
from mpmath import mp, fp
|
||||
|
||||
def test_axes():
|
||||
try:
|
||||
import pylab
|
||||
except ImportError:
|
||||
print "\nSkipping test (pylab not available)\n"
|
||||
return
|
||||
fig = pylab.figure()
|
||||
axes = fig.add_subplot(111)
|
||||
for ctx in [mp, fp]:
|
||||
ctx.plot(lambda x: x**2, [0, 3], axes=axes)
|
||||
assert axes.get_xlabel() == 'x'
|
||||
assert axes.get_ylabel() == 'f(x)'
|
||||
|
||||
fig = pylab.figure()
|
||||
axes = fig.add_subplot(111)
|
||||
for ctx in [mp, fp]:
|
||||
ctx.cplot(lambda z: z, [-2, 2], [-10, 10], axes=axes)
|
||||
assert axes.get_xlabel() == 'Re(z)'
|
||||
assert axes.get_ylabel() == 'Im(z)'
|
||||
|
|
@ -1,229 +0,0 @@
|
|||
"""
|
||||
Torture tests for asymptotics and high precision evaluation of
|
||||
special functions.
|
||||
|
||||
(Other torture tests may also be placed here.)
|
||||
|
||||
Running this file (gmpy and psyco recommended!) takes several CPU minutes.
|
||||
With Python 2.6+, multiprocessing is used automatically to run tests
|
||||
in parallel if many cores are available. (A single test may take between
|
||||
a second and several minutes; possibly more.)
|
||||
|
||||
The idea:
|
||||
|
||||
* We evaluate functions at positive, negative, imaginary, 45- and 135-degree
|
||||
complex values with magnitudes between 10^-20 to 10^20, at precisions between
|
||||
5 and 150 digits (we can go even higher for fast functions).
|
||||
|
||||
* Comparing the result from two different precision levels provides
|
||||
a strong consistency check (particularly for functions that use
|
||||
different algorithms at different precision levels).
|
||||
|
||||
* That the computation finishes at all (without failure), within reasonable
|
||||
time, provides a check that evaluation works at all: that the code runs,
|
||||
that it doesn't get stuck in an infinite loop, and that it doesn't use
|
||||
some extremely slowly algorithm where it could use a faster one.
|
||||
|
||||
TODO:
|
||||
|
||||
* Speed up those functions that take long to finish!
|
||||
* Generalize to test more cases; more options.
|
||||
* Implement a timeout mechanism.
|
||||
* Some functions are notably absent, including the following:
|
||||
* inverse trigonometric functions (some become inaccurate for complex arguments)
|
||||
* ci, si (not implemented properly for large complex arguments)
|
||||
* zeta functions (need to modify test not to try too large imaginary values)
|
||||
* and others...
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import sys, os
|
||||
from timeit import default_timer as clock
|
||||
|
||||
if "-psyco" in sys.argv:
|
||||
sys.argv.remove('-psyco')
|
||||
import psyco
|
||||
psyco.full()
|
||||
|
||||
if "-nogmpy" in sys.argv:
|
||||
sys.argv.remove('-nogmpy')
|
||||
os.environ['MPMATH_NOGMPY'] = 'Y'
|
||||
|
||||
filt = ''
|
||||
if not sys.argv[-1].endswith(".py"):
|
||||
filt = sys.argv[-1]
|
||||
|
||||
from mpmath import *
|
||||
|
||||
def test_asymp(f, maxdps=150, verbose=False, huge_range=False):
|
||||
dps = [5,15,25,50,90,150,500,1500,5000,10000]
|
||||
dps = [p for p in dps if p <= maxdps]
|
||||
def check(x,y,p,inpt):
|
||||
if abs(x-y)/abs(y) < workprec(20)(power)(10, -p+1):
|
||||
return
|
||||
print
|
||||
print "Error!"
|
||||
print "Input:", inpt
|
||||
print "dps =", p
|
||||
print "Result 1:", x
|
||||
print "Result 2:", y
|
||||
print "Absolute error:", abs(x-y)
|
||||
print "Relative error:", abs(x-y)/abs(y)
|
||||
raise AssertionError
|
||||
exponents = range(-20,20)
|
||||
if huge_range:
|
||||
exponents += [-1000, -100, -50, 50, 100, 1000]
|
||||
for n in exponents:
|
||||
if verbose:
|
||||
print ".",
|
||||
mp.dps = 25
|
||||
xpos = mpf(10)**n / 1.1287
|
||||
xneg = -xpos
|
||||
ximag = xpos*j
|
||||
xcomplex1 = xpos*(1+j)
|
||||
xcomplex2 = xpos*(-1+j)
|
||||
for i in range(len(dps)):
|
||||
if verbose:
|
||||
print "Testing dps = %s" % dps[i]
|
||||
mp.dps = dps[i]
|
||||
new = f(xpos), f(xneg), f(ximag), f(xcomplex1), f(xcomplex2)
|
||||
if i != 0:
|
||||
p = dps[i-1]
|
||||
check(prev[0], new[0], p, xpos)
|
||||
check(prev[1], new[1], p, xneg)
|
||||
check(prev[2], new[2], p, ximag)
|
||||
check(prev[3], new[3], p, xcomplex1)
|
||||
check(prev[4], new[4], p, xcomplex2)
|
||||
prev = new
|
||||
if verbose:
|
||||
print
|
||||
|
||||
a1, a2, a3, a4, a5 = 1.5, -2.25, 3.125, 4, 2
|
||||
|
||||
def test_bernoulli_huge():
|
||||
p, q = bernfrac(9000)
|
||||
assert p % 10**10 == 9636701091
|
||||
assert q == 4091851784687571609141381951327092757255270
|
||||
mp.dps = 15
|
||||
assert str(bernoulli(10**100)) == '-2.58183325604736e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
|
||||
mp.dps = 50
|
||||
assert str(bernoulli(10**100)) == '-2.5818332560473632073252488656039475548106223822913e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
|
||||
mp.dps = 15
|
||||
|
||||
cases = """\
|
||||
test_bernoulli_huge()
|
||||
test_asymp(lambda z: +pi, maxdps=10000)
|
||||
test_asymp(lambda z: +e, maxdps=10000)
|
||||
test_asymp(lambda z: +ln2, maxdps=10000)
|
||||
test_asymp(lambda z: +ln10, maxdps=10000)
|
||||
test_asymp(lambda z: +phi, maxdps=10000)
|
||||
test_asymp(lambda z: +catalan, maxdps=5000)
|
||||
test_asymp(lambda z: +euler, maxdps=5000)
|
||||
test_asymp(lambda z: +glaisher, maxdps=1000)
|
||||
test_asymp(lambda z: +khinchin, maxdps=1000)
|
||||
test_asymp(lambda z: +twinprime, maxdps=150)
|
||||
test_asymp(lambda z: stieltjes(2), maxdps=150)
|
||||
test_asymp(lambda z: +mertens, maxdps=150)
|
||||
test_asymp(lambda z: +apery, maxdps=5000)
|
||||
test_asymp(sqrt, maxdps=10000, huge_range=True)
|
||||
test_asymp(cbrt, maxdps=5000, huge_range=True)
|
||||
test_asymp(lambda z: root(z,4), maxdps=5000, huge_range=True)
|
||||
test_asymp(lambda z: root(z,-5), maxdps=5000, huge_range=True)
|
||||
test_asymp(exp, maxdps=5000, huge_range=True)
|
||||
test_asymp(expm1, maxdps=1500)
|
||||
test_asymp(ln, maxdps=5000, huge_range=True)
|
||||
test_asymp(cosh, maxdps=5000)
|
||||
test_asymp(sinh, maxdps=5000)
|
||||
test_asymp(tanh, maxdps=1500)
|
||||
test_asymp(sin, maxdps=5000, huge_range=True)
|
||||
test_asymp(cos, maxdps=5000, huge_range=True)
|
||||
test_asymp(tan, maxdps=1500)
|
||||
test_asymp(agm, maxdps=1500, huge_range=True)
|
||||
test_asymp(ellipk, maxdps=1500)
|
||||
test_asymp(ellipe, maxdps=1500)
|
||||
test_asymp(lambertw, huge_range=True)
|
||||
test_asymp(lambda z: lambertw(z,-1))
|
||||
test_asymp(lambda z: lambertw(z,1))
|
||||
test_asymp(lambda z: lambertw(z,4))
|
||||
test_asymp(gamma)
|
||||
test_asymp(loggamma) # huge_range=True ?
|
||||
test_asymp(ei)
|
||||
test_asymp(e1)
|
||||
test_asymp(li, huge_range=True)
|
||||
test_asymp(ci)
|
||||
test_asymp(si)
|
||||
test_asymp(chi)
|
||||
test_asymp(shi)
|
||||
test_asymp(erf)
|
||||
test_asymp(erfc)
|
||||
test_asymp(erfi)
|
||||
test_asymp(lambda z: besselj(2, z))
|
||||
test_asymp(lambda z: bessely(2, z))
|
||||
test_asymp(lambda z: besseli(2, z))
|
||||
test_asymp(lambda z: besselk(2, z))
|
||||
test_asymp(lambda z: besselj(-2.25, z))
|
||||
test_asymp(lambda z: bessely(-2.25, z))
|
||||
test_asymp(lambda z: besseli(-2.25, z))
|
||||
test_asymp(lambda z: besselk(-2.25, z))
|
||||
test_asymp(airyai)
|
||||
test_asymp(airybi)
|
||||
test_asymp(lambda z: hyp0f1(a1, z))
|
||||
test_asymp(lambda z: hyp1f1(a1, a2, z))
|
||||
test_asymp(lambda z: hyp1f2(a1, a2, a3, z))
|
||||
test_asymp(lambda z: hyp2f0(a1, a2, z))
|
||||
test_asymp(lambda z: hyperu(a1, a2, z))
|
||||
test_asymp(lambda z: hyp2f1(a1, a2, a3, z))
|
||||
test_asymp(lambda z: hyp2f2(a1, a2, a3, a4, z))
|
||||
test_asymp(lambda z: hyp2f3(a1, a2, a3, a4, a5, z))
|
||||
test_asymp(lambda z: coulombf(a1, a2, z))
|
||||
test_asymp(lambda z: coulombg(a1, a2, z))
|
||||
test_asymp(lambda z: polylog(2,z))
|
||||
test_asymp(lambda z: polylog(3,z))
|
||||
test_asymp(lambda z: polylog(-2,z))
|
||||
test_asymp(lambda z: expint(4, z))
|
||||
test_asymp(lambda z: expint(-4, z))
|
||||
test_asymp(lambda z: expint(2.25, z))
|
||||
test_asymp(lambda z: gammainc(2.5, z, 5))
|
||||
test_asymp(lambda z: gammainc(2.5, 5, z))
|
||||
test_asymp(lambda z: hermite(3, z))
|
||||
test_asymp(lambda z: hermite(2.5, z))
|
||||
test_asymp(lambda z: legendre(3, z))
|
||||
test_asymp(lambda z: legendre(4, z))
|
||||
test_asymp(lambda z: legendre(2.5, z))
|
||||
test_asymp(lambda z: legenp(a1, a2, z))
|
||||
test_asymp(lambda z: legenq(a1, a2, z), maxdps=90) # abnormally slow
|
||||
test_asymp(lambda z: jtheta(1, z, 0.5))
|
||||
test_asymp(lambda z: jtheta(2, z, 0.5))
|
||||
test_asymp(lambda z: jtheta(3, z, 0.5))
|
||||
test_asymp(lambda z: jtheta(4, z, 0.5))
|
||||
test_asymp(lambda z: jtheta(1, z, 0.5, 1))
|
||||
test_asymp(lambda z: jtheta(2, z, 0.5, 1))
|
||||
test_asymp(lambda z: jtheta(3, z, 0.5, 1))
|
||||
test_asymp(lambda z: jtheta(4, z, 0.5, 1))
|
||||
test_asymp(barnesg, maxdps=90)
|
||||
"""
|
||||
|
||||
def testit(line):
|
||||
if filt in line:
|
||||
print line
|
||||
t1 = clock()
|
||||
exec line
|
||||
t2 = clock()
|
||||
elapsed = t2-t1
|
||||
print "Time:", elapsed, "for", line, "(OK)"
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
from multiprocessing import Pool
|
||||
mapf = Pool(None).map
|
||||
print "Running tests with multiprocessing"
|
||||
except ImportError:
|
||||
print "Not using multiprocessing"
|
||||
mapf = map
|
||||
t1 = clock()
|
||||
tasks = cases.splitlines()
|
||||
mapf(testit, tasks)
|
||||
t2 = clock()
|
||||
print "Cumulative wall time:", t2-t1
|
||||
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
|
||||
def monitor(f, input='print', output='print'):
|
||||
"""
|
||||
Returns a wrapped copy of *f* that monitors evaluation by calling
|
||||
*input* with every input (*args*, *kwargs*) passed to *f* and
|
||||
*output* with every value returned from *f*. The default action
|
||||
(specify using the special string value ``'print'``) is to print
|
||||
inputs and outputs to stdout, along with the total evaluation
|
||||
count::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> mp.dps = 5; mp.pretty = False
|
||||
>>> diff(monitor(exp), 1) # diff will eval f(x-h) and f(x+h)
|
||||
in 0 (mpf('0.99999999906867742538452148'),) {}
|
||||
out 0 mpf('2.7182818259274480055282064')
|
||||
in 1 (mpf('1.0000000009313225746154785'),) {}
|
||||
out 1 mpf('2.7182818309906424675501024')
|
||||
mpf('2.7182808')
|
||||
|
||||
To disable either the input or the output handler, you may
|
||||
pass *None* as argument.
|
||||
|
||||
Custom input and output handlers may be used e.g. to store
|
||||
results for later analysis::
|
||||
|
||||
>>> mp.dps = 15
|
||||
>>> input = []
|
||||
>>> output = []
|
||||
>>> findroot(monitor(sin, input.append, output.append), 3.0)
|
||||
mpf('3.1415926535897932')
|
||||
>>> len(input) # Count number of evaluations
|
||||
9
|
||||
>>> print input[3], output[3]
|
||||
((mpf('3.1415076583334066'),), {}) 8.49952562843408e-5
|
||||
>>> print input[4], output[4]
|
||||
((mpf('3.1415928201669122'),), {}) -1.66577118985331e-7
|
||||
|
||||
"""
|
||||
if not input:
|
||||
input = lambda v: None
|
||||
elif input == 'print':
|
||||
incount = [0]
|
||||
def input(value):
|
||||
args, kwargs = value
|
||||
print "in %s %r %r" % (incount[0], args, kwargs)
|
||||
incount[0] += 1
|
||||
if not output:
|
||||
output = lambda v: None
|
||||
elif output == 'print':
|
||||
outcount = [0]
|
||||
def output(value):
|
||||
print "out %s %r" % (outcount[0], value)
|
||||
outcount[0] += 1
|
||||
def f_monitored(*args, **kwargs):
|
||||
input((args, kwargs))
|
||||
v = f(*args, **kwargs)
|
||||
output(v)
|
||||
return v
|
||||
return f_monitored
|
||||
|
||||
def timing(f, *args, **kwargs):
|
||||
"""
|
||||
Returns time elapsed for evaluating ``f()``. Optionally arguments
|
||||
may be passed to time the execution of ``f(*args, **kwargs)``.
|
||||
|
||||
If the first call is very quick, ``f`` is called
|
||||
repeatedly and the best time is returned.
|
||||
"""
|
||||
once = kwargs.get('once')
|
||||
if 'once' in kwargs:
|
||||
del kwargs['once']
|
||||
if args or kwargs:
|
||||
if len(args) == 1 and not kwargs:
|
||||
arg = args[0]
|
||||
g = lambda: f(arg)
|
||||
else:
|
||||
g = lambda: f(*args, **kwargs)
|
||||
else:
|
||||
g = f
|
||||
from timeit import default_timer as clock
|
||||
t1=clock(); v=g(); t2=clock(); t=t2-t1
|
||||
if t > 0.05 or once:
|
||||
return t
|
||||
for i in range(3):
|
||||
t1=clock();
|
||||
# Evaluate multiple times because the timer function
|
||||
# has a significant overhead
|
||||
g();g();g();g();g();g();g();g();g();g()
|
||||
t2=clock()
|
||||
t=min(t,(t2-t1)/10)
|
||||
return t
|
||||
|
|
@ -1,270 +0,0 @@
|
|||
"""
|
||||
Plotting (requires matplotlib)
|
||||
"""
|
||||
|
||||
from colorsys import hsv_to_rgb, hls_to_rgb
|
||||
|
||||
class VisualizationMethods(object):
|
||||
plot_ignore = (ValueError, ArithmeticError, ZeroDivisionError)
|
||||
|
||||
def plot(ctx, f, xlim=[-5,5], ylim=None, points=200, file=None, dpi=None,
|
||||
singularities=[], axes=None):
|
||||
r"""
|
||||
Shows a simple 2D plot of a function `f(x)` or list of functions
|
||||
`[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
|
||||
specified by *xlim*. Some examples::
|
||||
|
||||
plot(lambda x: exp(x)*li(x), [1, 4])
|
||||
plot([cos, sin], [-4, 4])
|
||||
plot([fresnels, fresnelc], [-4, 4])
|
||||
plot([sqrt, cbrt], [-4, 4])
|
||||
plot(lambda t: zeta(0.5+t*j), [-20, 20])
|
||||
plot([floor, ceil, abs, sign], [-5, 5])
|
||||
|
||||
Points where the function raises a numerical exception or
|
||||
returns an infinite value are removed from the graph.
|
||||
Singularities can also be excluded explicitly
|
||||
as follows (useful for removing erroneous vertical lines)::
|
||||
|
||||
plot(cot, ylim=[-5, 5]) # bad
|
||||
plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi]) # good
|
||||
|
||||
For parts where the function assumes complex values, the
|
||||
real part is plotted with dashes and the imaginary part
|
||||
is plotted with dots.
|
||||
|
||||
NOTE: This function requires matplotlib (pylab).
|
||||
"""
|
||||
if file:
|
||||
axes = None
|
||||
fig = None
|
||||
if not axes:
|
||||
import pylab
|
||||
fig = pylab.figure()
|
||||
axes = fig.add_subplot(111)
|
||||
if not isinstance(f, (tuple, list)):
|
||||
f = [f]
|
||||
a, b = xlim
|
||||
colors = ['b', 'r', 'g', 'm', 'k']
|
||||
for n, func in enumerate(f):
|
||||
x = ctx.arange(a, b, (b-a)/float(points))
|
||||
segments = []
|
||||
segment = []
|
||||
in_complex = False
|
||||
for i in xrange(len(x)):
|
||||
try:
|
||||
if i != 0:
|
||||
for sing in singularities:
|
||||
if x[i-1] <= sing and x[i] >= sing:
|
||||
raise ValueError
|
||||
v = func(x[i])
|
||||
if ctx.isnan(v) or abs(v) > 1e300:
|
||||
raise ValueError
|
||||
if hasattr(v, "imag") and v.imag:
|
||||
re = float(v.real)
|
||||
im = float(v.imag)
|
||||
if not in_complex:
|
||||
in_complex = True
|
||||
segments.append(segment)
|
||||
segment = []
|
||||
segment.append((float(x[i]), re, im))
|
||||
else:
|
||||
if in_complex:
|
||||
in_complex = False
|
||||
segments.append(segment)
|
||||
segment = []
|
||||
segment.append((float(x[i]), v))
|
||||
except ctx.plot_ignore:
|
||||
if segment:
|
||||
segments.append(segment)
|
||||
segment = []
|
||||
if segment:
|
||||
segments.append(segment)
|
||||
for segment in segments:
|
||||
x = [s[0] for s in segment]
|
||||
y = [s[1] for s in segment]
|
||||
if not x:
|
||||
continue
|
||||
c = colors[n % len(colors)]
|
||||
if len(segment[0]) == 3:
|
||||
z = [s[2] for s in segment]
|
||||
axes.plot(x, y, '--'+c, linewidth=3)
|
||||
axes.plot(x, z, ':'+c, linewidth=3)
|
||||
else:
|
||||
axes.plot(x, y, c, linewidth=3)
|
||||
axes.set_xlim(xlim)
|
||||
if ylim:
|
||||
axes.set_ylim(ylim)
|
||||
axes.set_xlabel('x')
|
||||
axes.set_ylabel('f(x)')
|
||||
axes.grid(True)
|
||||
if fig:
|
||||
if file:
|
||||
pylab.savefig(file, dpi=dpi)
|
||||
else:
|
||||
pylab.show()
|
||||
|
||||
def default_color_function(ctx, z):
|
||||
if ctx.isinf(z):
|
||||
return (1.0, 1.0, 1.0)
|
||||
if ctx.isnan(z):
|
||||
return (0.5, 0.5, 0.5)
|
||||
pi = 3.1415926535898
|
||||
a = (float(ctx.arg(z)) + ctx.pi) / (2*ctx.pi)
|
||||
a = (a + 0.5) % 1.0
|
||||
b = 1.0 - float(1/(1.0+abs(z)**0.3))
|
||||
return hls_to_rgb(a, b, 0.8)
|
||||
|
||||
def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
|
||||
verbose=False, file=None, dpi=None, axes=None):
|
||||
"""
|
||||
Plots the given complex-valued function *f* over a rectangular part
|
||||
of the complex plane specified by the pairs of intervals *re* and *im*.
|
||||
For example::
|
||||
|
||||
cplot(lambda z: z, [-2, 2], [-10, 10])
|
||||
cplot(exp)
|
||||
cplot(zeta, [0, 1], [0, 50])
|
||||
|
||||
By default, the complex argument (phase) is shown as color (hue) and
|
||||
the magnitude is show as brightness. You can also supply a
|
||||
custom color function (*color*). This function should take a
|
||||
complex number as input and return an RGB 3-tuple containing
|
||||
floats in the range 0.0-1.0.
|
||||
|
||||
To obtain a sharp image, the number of points may need to be
|
||||
increased to 100,000 or thereabout. Since evaluating the
|
||||
function that many times is likely to be slow, the 'verbose'
|
||||
option is useful to display progress.
|
||||
|
||||
NOTE: This function requires matplotlib (pylab).
|
||||
"""
|
||||
if color is None:
|
||||
color = ctx.default_color_function
|
||||
import pylab
|
||||
if file:
|
||||
axes = None
|
||||
fig = None
|
||||
if not axes:
|
||||
fig = pylab.figure()
|
||||
axes = fig.add_subplot(111)
|
||||
rea, reb = re
|
||||
ima, imb = im
|
||||
dre = reb - rea
|
||||
dim = imb - ima
|
||||
M = int(ctx.sqrt(points*dre/dim)+1)
|
||||
N = int(ctx.sqrt(points*dim/dre)+1)
|
||||
x = pylab.linspace(rea, reb, M)
|
||||
y = pylab.linspace(ima, imb, N)
|
||||
# Note: we have to be careful to get the right rotation.
|
||||
# Test with these plots:
|
||||
# cplot(lambda z: z if z.real < 0 else 0)
|
||||
# cplot(lambda z: z if z.imag < 0 else 0)
|
||||
w = pylab.zeros((N, M, 3))
|
||||
for n in xrange(N):
|
||||
for m in xrange(M):
|
||||
z = ctx.mpc(x[m], y[n])
|
||||
try:
|
||||
v = color(f(z))
|
||||
except ctx.plot_ignore:
|
||||
v = (0.5, 0.5, 0.5)
|
||||
w[n,m] = v
|
||||
if verbose:
|
||||
print n, "of", N
|
||||
axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
|
||||
axes.set_xlabel('Re(z)')
|
||||
axes.set_ylabel('Im(z)')
|
||||
if fig:
|
||||
if file:
|
||||
pylab.savefig(file, dpi=dpi)
|
||||
else:
|
||||
pylab.show()
|
||||
|
||||
def splot(ctx, f, u=[-5,5], v=[-5,5], points=100, keep_aspect=True, \
|
||||
wireframe=False, file=None, dpi=None, axes=None):
|
||||
"""
|
||||
Plots the surface defined by `f`.
|
||||
|
||||
If `f` returns a single component, then this plots the surface
|
||||
defined by `z = f(x,y)` over the rectangular domain with
|
||||
`x = u` and `y = v`.
|
||||
|
||||
If `f` returns three components, then this plots the parametric
|
||||
surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.
|
||||
|
||||
For example, to plot a simple function::
|
||||
|
||||
>>> from mpmath import *
|
||||
>>> f = lambda x, y: sin(x+y)*cos(y)
|
||||
>>> splot(f, [-pi,pi], [-pi,pi]) # doctest: +SKIP
|
||||
|
||||
Plotting a donut::
|
||||
|
||||
>>> r, R = 1, 2.5
|
||||
>>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
|
||||
>>> splot(f, [0, 2*pi], [0, 2*pi]) # doctest: +SKIP
|
||||
|
||||
NOTE: This function requires matplotlib (pylab) 0.98.5.3 or higher.
|
||||
"""
|
||||
import pylab
|
||||
import mpl_toolkits.mplot3d as mplot3d
|
||||
if file:
|
||||
axes = None
|
||||
fig = None
|
||||
if not axes:
|
||||
fig = pylab.figure()
|
||||
axes = mplot3d.axes3d.Axes3D(fig)
|
||||
ua, ub = u
|
||||
va, vb = v
|
||||
du = ub - ua
|
||||
dv = vb - va
|
||||
if not isinstance(points, (list, tuple)):
|
||||
points = [points, points]
|
||||
M, N = points
|
||||
u = pylab.linspace(ua, ub, M)
|
||||
v = pylab.linspace(va, vb, N)
|
||||
x, y, z = [pylab.zeros((M, N)) for i in xrange(3)]
|
||||
xab, yab, zab = [[0, 0] for i in xrange(3)]
|
||||
for n in xrange(N):
|
||||
for m in xrange(M):
|
||||
fdata = f(ctx.convert(u[m]), ctx.convert(v[n]))
|
||||
try:
|
||||
x[m,n], y[m,n], z[m,n] = fdata
|
||||
except TypeError:
|
||||
x[m,n], y[m,n], z[m,n] = u[m], v[n], fdata
|
||||
for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
|
||||
if c < cab[0]:
|
||||
cab[0] = c
|
||||
if c > cab[1]:
|
||||
cab[1] = c
|
||||
if wireframe:
|
||||
axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
|
||||
else:
|
||||
axes.plot_surface(x, y, z, rstride=4, cstride=4)
|
||||
axes.set_xlabel('x')
|
||||
axes.set_ylabel('y')
|
||||
axes.set_zlabel('z')
|
||||
if keep_aspect:
|
||||
dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
|
||||
maxd = max(dx, dy, dz)
|
||||
if dx < maxd:
|
||||
delta = maxd - dx
|
||||
axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
|
||||
if dy < maxd:
|
||||
delta = maxd - dy
|
||||
axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
|
||||
if dz < maxd:
|
||||
delta = maxd - dz
|
||||
axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
|
||||
if fig:
|
||||
if file:
|
||||
pylab.savefig(file, dpi=dpi)
|
||||
else:
|
||||
pylab.show()
|
||||
|
||||
|
||||
VisualizationMethods.plot = plot
|
||||
VisualizationMethods.default_color_function = default_color_function
|
||||
VisualizationMethods.cplot = cplot
|
||||
VisualizationMethods.splot = splot
|
||||
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
# -*- coding: ISO-8859-1 -*-
|
||||
#
|
||||
#
|
||||
# Copyright (C) 2002-2005 Jörg Lehmann <joergl@users.sourceforge.net>
|
||||
# Copyright (C) 2002-2006 André Wobst <wobsta@users.sourceforge.net>
|
||||
# Copyright (C) 2002-2005 Jorg Lehmann <joergl@users.sourceforge.net>
|
||||
# Copyright (C) 2002-2006 Andre Wobst <wobsta@users.sourceforge.net>
|
||||
#
|
||||
# This file is part of PyX (http://pyx.sourceforge.net/).
|
||||
#
|
||||
|
|
@ -28,8 +27,8 @@ interface. Complex tasks like 2d and 3d plots in publication-ready quality are
|
|||
built out of these primitives.
|
||||
"""
|
||||
|
||||
import version
|
||||
__version__ = version.version
|
||||
from .version import version
|
||||
__version__ = version
|
||||
|
||||
__all__ = ["attr", "box", "bitmap", "canvas", "color", "connector", "deco", "deformer", "document",
|
||||
"epsfile", "graph", "mesh", "path", "pattern", "style", "trafo", "text", "unit"]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#!/usr/bin/env python2.7
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This script will generate a stimulus file for a given period, load, and slew input
|
||||
for the given dimension SRAM. It is useful for debugging after an SRAM has been
|
||||
|
|
|
|||
|
|
@ -89,11 +89,11 @@ def print_banner():
|
|||
def check_versions():
|
||||
""" Run some checks of required software versions. """
|
||||
|
||||
# check that we are not using version 3 and at least 2.7
|
||||
# Now require python >=3.6
|
||||
major_python_version = sys.version_info.major
|
||||
minor_python_version = sys.version_info.minor
|
||||
if not (major_python_version == 2 and minor_python_version >= 7):
|
||||
debug.error("Python 2.7 is required.",-1)
|
||||
if not (major_python_version == 3 and minor_python_version >= 6):
|
||||
debug.error("Python 3.6 or greater is required.",-1)
|
||||
|
||||
# FIXME: Check versions of other tools here??
|
||||
# or, this could be done in each module (e.g. verify, characterizer, etc.)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class bank(design.design):
|
|||
"bitcell_array", "sense_amp_array", "precharge_array",
|
||||
"column_mux_array", "write_driver_array", "tri_gate_array",
|
||||
"bank_select"]
|
||||
from importlib import reload
|
||||
for mod_name in mod_list:
|
||||
config_mod_name = getattr(OPTS, mod_name)
|
||||
class_file = reload(__import__(config_mod_name))
|
||||
|
|
@ -130,8 +131,8 @@ class bank(design.design):
|
|||
def compute_sizes(self):
|
||||
""" Computes the required sizes to create the bank """
|
||||
|
||||
self.num_cols = self.words_per_row*self.word_size
|
||||
self.num_rows = self.num_words / self.words_per_row
|
||||
self.num_cols = int(self.words_per_row*self.word_size)
|
||||
self.num_rows = int(self.num_words / self.words_per_row)
|
||||
|
||||
self.row_addr_size = int(log(self.num_rows, 2))
|
||||
self.col_addr_size = int(log(self.words_per_row, 2))
|
||||
|
|
@ -320,7 +321,7 @@ class bank(design.design):
|
|||
y_offset = self.sense_amp_array.height+self.column_mux_height \
|
||||
+ self.write_driver_array.height + self.m2_gap + self.tri_gate_array.height
|
||||
self.tri_gate_array_inst=self.add_inst(name="tri_gate_array",
|
||||
mod=self.tri_gate_array,
|
||||
mod=self.tri_gate_array,
|
||||
offset=vector(0,y_offset).scale(-1,-1))
|
||||
|
||||
temp = []
|
||||
|
|
@ -852,9 +853,7 @@ class bank(design.design):
|
|||
|
||||
def analytical_delay(self, slew, load):
|
||||
""" return analytical delay of the bank"""
|
||||
msf_addr_delay = self.msf_address.analytical_delay(slew, self.row_decoder.input_load())
|
||||
|
||||
decoder_delay = self.row_decoder.analytical_delay(msf_addr_delay.slew, self.wordline_driver.input_load())
|
||||
decoder_delay = self.row_decoder.analytical_delay(slew, self.wordline_driver.input_load())
|
||||
|
||||
word_driver_delay = self.wordline_driver.analytical_delay(decoder_delay.slew, self.bitcell_array.input_load())
|
||||
|
||||
|
|
@ -866,7 +865,6 @@ class bank(design.design):
|
|||
|
||||
data_t_DATA_delay = self.tri_gate_array.analytical_delay(bl_t_data_out_delay.slew, load)
|
||||
|
||||
result = msf_addr_delay + decoder_delay + word_driver_delay \
|
||||
+ bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
|
||||
result = decoder_delay + word_driver_delay + bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class bitcell_array(design.design):
|
|||
self.column_size = cols
|
||||
self.row_size = rows
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.bitcell))
|
||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||
self.cell = self.mod_bitcell()
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ class control_logic(design.design):
|
|||
self.inv8 = pinv(size=16, height=dff_height)
|
||||
self.add_mod(self.inv8)
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.replica_bitline))
|
||||
replica_bitline = getattr(c, OPTS.replica_bitline)
|
||||
# FIXME: These should be tuned according to the size!
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class delay_chain(design.design):
|
|||
self.num_inverters = 1 + sum(fanout_list)
|
||||
self.num_top_half = round(self.num_inverters / 2.0)
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.bitcell))
|
||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||
self.bitcell = self.mod_bitcell()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class dff_array(design.design):
|
|||
design.design.__init__(self, name)
|
||||
debug.info(1, "Creating {}".format(self.name))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.dff))
|
||||
self.mod_dff = getattr(c, OPTS.dff)
|
||||
self.dff = self.mod_dff("dff")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class dff_buf(design.design):
|
|||
design.design.__init__(self, name)
|
||||
debug.info(1, "Creating {}".format(self.name))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.dff))
|
||||
self.mod_dff = getattr(c, OPTS.dff)
|
||||
self.dff = self.mod_dff("dff")
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class dff_inv(design.design):
|
|||
design.design.__init__(self, name)
|
||||
debug.info(1, "Creating {}".format(self.name))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.dff))
|
||||
self.mod_dff = getattr(c, OPTS.dff)
|
||||
self.dff = self.mod_dff("dff")
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class hierarchical_decoder(design.design):
|
|||
def __init__(self, rows):
|
||||
design.design.__init__(self, "hierarchical_decoder_{0}rows".format(rows))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.bitcell))
|
||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||
self.bitcell_height = self.mod_bitcell.height
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class hierarchical_predecode(design.design):
|
|||
self.number_of_outputs = int(math.pow(2, self.number_of_inputs))
|
||||
design.design.__init__(self, name="pre{0}x{1}".format(self.number_of_inputs,self.number_of_outputs))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.bitcell))
|
||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class ms_flop_array(design.design):
|
|||
design.design.__init__(self, name)
|
||||
debug.info(1, "Creating {}".format(self.name))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.ms_flop))
|
||||
self.mod_ms_flop = getattr(c, OPTS.ms_flop)
|
||||
self.ms = self.mod_ms_flop("ms_flop")
|
||||
|
|
@ -27,7 +28,7 @@ class ms_flop_array(design.design):
|
|||
|
||||
self.width = self.columns * self.ms.width
|
||||
self.height = self.ms.height
|
||||
self.words_per_row = self.columns / self.word_size
|
||||
self.words_per_row = int(self.columns / self.word_size)
|
||||
|
||||
self.create_layout()
|
||||
|
||||
|
|
@ -57,13 +58,16 @@ class ms_flop_array(design.design):
|
|||
else:
|
||||
base = vector((i+1)*self.ms.width,0)
|
||||
mirror = "MY"
|
||||
self.ms_inst[i/self.words_per_row]=self.add_inst(name=name,
|
||||
|
||||
index = int(i/self.words_per_row)
|
||||
|
||||
self.ms_inst[index]=self.add_inst(name=name,
|
||||
mod=self.ms,
|
||||
offset=base,
|
||||
mirror=mirror)
|
||||
self.connect_inst(["din[{0}]".format(i/self.words_per_row),
|
||||
"dout[{0}]".format(i/self.words_per_row),
|
||||
"dout_bar[{0}]".format(i/self.words_per_row),
|
||||
self.connect_inst(["din[{0}]".format(index),
|
||||
"dout[{0}]".format(index),
|
||||
"dout_bar[{0}]".format(index),
|
||||
"clk",
|
||||
"vdd", "gnd"])
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class replica_bitline(design.design):
|
|||
def __init__(self, delay_stages, delay_fanout, bitcell_loads, name="replica_bitline"):
|
||||
design.design.__init__(self, name)
|
||||
|
||||
from importlib import reload
|
||||
g = reload(__import__(OPTS.delay_chain))
|
||||
self.mod_delay_chain = getattr(g, OPTS.delay_chain)
|
||||
|
||||
|
|
@ -132,11 +133,10 @@ class replica_bitline(design.design):
|
|||
""" Connect all the signals together """
|
||||
self.route_vdd()
|
||||
self.route_gnd()
|
||||
self.route_vdd_gnd()
|
||||
self.route_access_tx()
|
||||
|
||||
def route_vdd_gnd(self):
|
||||
""" Route all the vdd and gnd pins to the top level """
|
||||
def route_vdd_gnd(self):
|
||||
""" Propagate all vdd/gnd pins up to this level for all modules """
|
||||
|
||||
# These are the instances that every bank has
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ class sense_amp_array(design.design):
|
|||
design.design.__init__(self, "sense_amp_array")
|
||||
debug.info(1, "Creating {0}".format(self.name))
|
||||
|
||||
from importlib import reload
|
||||
c = reload(__import__(OPTS.sense_amp))
|
||||
self.mod_sense_amp = getattr(c, OPTS.sense_amp)
|
||||
self.amp = self.mod_sense_amp("sense_amp")
|
||||
|
|
@ -33,7 +34,8 @@ class sense_amp_array(design.design):
|
|||
def add_pins(self):
|
||||
|
||||
for i in range(0,self.row_size,self.words_per_row):
|
||||
self.add_pin("data[{0}]".format(i/self.words_per_row))
|
||||
index = int(i/self.words_per_row)
|
||||
self.add_pin("data[{0}]".format(index))
|
||||
self.add_pin("bl[{0}]".format(i))
|
||||
self.add_pin("br[{0}]".format(i))
|
||||
|
||||
|
|
@ -62,12 +64,14 @@ class sense_amp_array(design.design):
|
|||
br_offset = amp_position + br_pin.ll().scale(1,0)
|
||||
dout_offset = amp_position + dout_pin.ll()
|
||||
|
||||
index = int(i/self.words_per_row)
|
||||
|
||||
inst = self.add_inst(name=name,
|
||||
mod=self.amp,
|
||||
offset=amp_position)
|
||||
self.connect_inst(["bl[{0}]".format(i),
|
||||
"br[{0}]".format(i),
|
||||
"data[{0}]".format(i/self.words_per_row),
|
||||
"data[{0}]".format(index),
|
||||
"en", "vdd", "gnd"])
|
||||
|
||||
|
||||
|
|
@ -85,19 +89,18 @@ class sense_amp_array(design.design):
|
|||
layer="metal3",
|
||||
offset=vdd_pos)
|
||||
|
||||
|
||||
self.add_layout_pin(text="bl[{0}]".format(i/self.words_per_row),
|
||||
self.add_layout_pin(text="bl[{0}]".format(i),
|
||||
layer="metal2",
|
||||
offset=bl_offset,
|
||||
width=bl_pin.width(),
|
||||
height=bl_pin.height())
|
||||
self.add_layout_pin(text="br[{0}]".format(i/self.words_per_row),
|
||||
self.add_layout_pin(text="br[{0}]".format(i),
|
||||
layer="metal2",
|
||||
offset=br_offset,
|
||||
width=br_pin.width(),
|
||||
height=br_pin.height())
|
||||
|
||||
self.add_layout_pin(text="data[{0}]".format(i/self.words_per_row),
|
||||
self.add_layout_pin(text="data[{0}]".format(index),
|
||||
layer="metal2",
|
||||
offset=dout_offset,
|
||||
width=dout_pin.width(),
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue