mirror of https://github.com/VLSIDA/OpenRAM.git
commiting changes from most recent pull from dev
This commit is contained in:
commit
8f131ddb2f
|
|
@ -29,18 +29,18 @@ class design(hierarchy_spice.spice, hierarchy_layout.layout):
|
||||||
# because each reference must be a unique name.
|
# because each reference must be a unique name.
|
||||||
# These modules ensure unique names or have no changes if they
|
# These modules ensure unique names or have no changes if they
|
||||||
# aren't unique
|
# aren't unique
|
||||||
ok_list = ['ms_flop.ms_flop',
|
ok_list = ['ms_flop',
|
||||||
'dff.dff',
|
'dff',
|
||||||
'dff_buf.dff_buf',
|
'dff_buf',
|
||||||
'bitcell.bitcell',
|
'bitcell',
|
||||||
'contact.contact',
|
'contact',
|
||||||
'ptx.ptx',
|
'ptx',
|
||||||
'sram.sram',
|
'sram',
|
||||||
'hierarchical_predecode2x4.hierarchical_predecode2x4',
|
'hierarchical_predecode2x4',
|
||||||
'hierarchical_predecode3x8.hierarchical_predecode3x8']
|
'hierarchical_predecode3x8']
|
||||||
if name not in design.name_map:
|
if name not in design.name_map:
|
||||||
design.name_map.append(name)
|
design.name_map.append(name)
|
||||||
elif str(self.__class__) in ok_list:
|
elif self.__class__.__name__ in ok_list:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
debug.error("Duplicate layout reference name {0} of class {1}. GDS2 requires names be unique.".format(name,self.__class__),-1)
|
debug.error("Duplicate layout reference name {0} of class {1}. GDS2 requires names be unique.".format(name,self.__class__),-1)
|
||||||
|
|
|
||||||
|
|
@ -116,10 +116,11 @@ class spice(verilog.verilog):
|
||||||
self.spice = f.readlines()
|
self.spice = f.readlines()
|
||||||
for i in range(len(self.spice)):
|
for i in range(len(self.spice)):
|
||||||
self.spice[i] = self.spice[i].rstrip(" \n")
|
self.spice[i] = self.spice[i].rstrip(" \n")
|
||||||
|
f.close()
|
||||||
|
|
||||||
# find the correct subckt line in the file
|
# find the correct subckt line in the file
|
||||||
subckt = re.compile("^.subckt {}".format(self.name), re.IGNORECASE)
|
subckt = re.compile("^.subckt {}".format(self.name), re.IGNORECASE)
|
||||||
subckt_line = filter(subckt.search, self.spice)[0]
|
subckt_line = list(filter(subckt.search, self.spice))[0]
|
||||||
# parses line into ports and remove subckt
|
# parses line into ports and remove subckt
|
||||||
self.pins = subckt_line.split(" ")[2:]
|
self.pins = subckt_line.split(" ")[2:]
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ class pin_layout:
|
||||||
self.rect = [x.snap_to_grid() for x in self.rect]
|
self.rect = [x.snap_to_grid() for x in self.rect]
|
||||||
# if it's a layer number look up the layer name. this assumes a unique layer number.
|
# if it's a layer number look up the layer name. this assumes a unique layer number.
|
||||||
if type(layer_name_num)==int:
|
if type(layer_name_num)==int:
|
||||||
self.layer = layer.keys()[layer.values().index(layer_name_num)]
|
self.layer = list(layer.keys())[list(layer.values()).index(layer_name_num)]
|
||||||
else:
|
else:
|
||||||
self.layer=layer_name_num
|
self.layer=layer_name_num
|
||||||
self.layer_num = layer[self.layer]
|
self.layer_num = layer[self.layer]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import debug
|
import debug
|
||||||
from globals import OPTS,find_exe,get_tool
|
from globals import OPTS,find_exe,get_tool
|
||||||
import lib
|
from .lib import *
|
||||||
import delay
|
from .delay import *
|
||||||
import setup_hold
|
from .setup_hold import *
|
||||||
|
|
||||||
|
|
||||||
debug.info(2,"Initializing characterizer...")
|
debug.info(2,"Initializing characterizer...")
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import sys,re,shutil
|
||||||
import debug
|
import debug
|
||||||
import tech
|
import tech
|
||||||
import math
|
import math
|
||||||
import stimuli
|
from .stimuli import *
|
||||||
from trim_spice import trim_spice
|
from .trim_spice import *
|
||||||
import charutils as ch
|
from .charutils import *
|
||||||
import utils
|
import utils
|
||||||
from globals import OPTS
|
from globals import OPTS
|
||||||
|
|
||||||
|
|
@ -101,7 +101,7 @@ class delay():
|
||||||
self.sf.write("* Delay stimulus for period of {0}n load={1}fF slew={2}ns\n\n".format(self.period,
|
self.sf.write("* Delay stimulus for period of {0}n load={1}fF slew={2}ns\n\n".format(self.period,
|
||||||
self.load,
|
self.load,
|
||||||
self.slew))
|
self.slew))
|
||||||
self.stim = stimuli.stimuli(self.sf, self.corner)
|
self.stim = stimuli(self.sf, self.corner)
|
||||||
# include files in stimulus file
|
# include files in stimulus file
|
||||||
self.stim.write_include(self.trim_sp_file)
|
self.stim.write_include(self.trim_sp_file)
|
||||||
|
|
||||||
|
|
@ -339,16 +339,16 @@ class delay():
|
||||||
# Checking from not data_value to data_value
|
# Checking from not data_value to data_value
|
||||||
self.write_delay_stimulus()
|
self.write_delay_stimulus()
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
delay_hl = ch.parse_output("timing", "delay_hl")
|
delay_hl = parse_output("timing", "delay_hl")
|
||||||
delay_lh = ch.parse_output("timing", "delay_lh")
|
delay_lh = parse_output("timing", "delay_lh")
|
||||||
slew_hl = ch.parse_output("timing", "slew_hl")
|
slew_hl = parse_output("timing", "slew_hl")
|
||||||
slew_lh = ch.parse_output("timing", "slew_lh")
|
slew_lh = parse_output("timing", "slew_lh")
|
||||||
delays = (delay_hl, delay_lh, slew_hl, slew_lh)
|
delays = (delay_hl, delay_lh, slew_hl, slew_lh)
|
||||||
|
|
||||||
read0_power=ch.parse_output("timing", "read0_power")
|
read0_power=parse_output("timing", "read0_power")
|
||||||
write0_power=ch.parse_output("timing", "write0_power")
|
write0_power=parse_output("timing", "write0_power")
|
||||||
read1_power=ch.parse_output("timing", "read1_power")
|
read1_power=parse_output("timing", "read1_power")
|
||||||
write1_power=ch.parse_output("timing", "write1_power")
|
write1_power=parse_output("timing", "write1_power")
|
||||||
|
|
||||||
if not self.check_valid_delays(delays):
|
if not self.check_valid_delays(delays):
|
||||||
return (False,{})
|
return (False,{})
|
||||||
|
|
@ -378,22 +378,24 @@ class delay():
|
||||||
|
|
||||||
self.write_power_stimulus(trim=False)
|
self.write_power_stimulus(trim=False)
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
leakage_power=ch.parse_output("timing", "leakage_power")
|
leakage_power=parse_output("timing", "leakage_power")
|
||||||
debug.check(leakage_power!="Failed","Could not measure leakage power.")
|
debug.check(leakage_power!="Failed","Could not measure leakage power.")
|
||||||
|
|
||||||
|
|
||||||
self.write_power_stimulus(trim=True)
|
self.write_power_stimulus(trim=True)
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
trim_leakage_power=ch.parse_output("timing", "leakage_power")
|
trim_leakage_power=parse_output("timing", "leakage_power")
|
||||||
debug.check(trim_leakage_power!="Failed","Could not measure leakage power.")
|
debug.check(trim_leakage_power!="Failed","Could not measure leakage power.")
|
||||||
|
|
||||||
# For debug, you sometimes want to inspect each simulation.
|
# For debug, you sometimes want to inspect each simulation.
|
||||||
#key=raw_input("press return to continue")
|
#key=raw_input("press return to continue")
|
||||||
return (leakage_power*1e3, trim_leakage_power*1e3)
|
return (leakage_power*1e3, trim_leakage_power*1e3)
|
||||||
|
|
||||||
def check_valid_delays(self, (delay_hl, delay_lh, slew_hl, slew_lh)):
|
def check_valid_delays(self, delay_tuple):
|
||||||
""" Check if the measurements are defined and if they are valid. """
|
""" Check if the measurements are defined and if they are valid. """
|
||||||
|
|
||||||
|
(delay_hl, delay_lh, slew_hl, slew_lh) = delay_tuple
|
||||||
|
|
||||||
# if it failed or the read was longer than a period
|
# if it failed or the read was longer than a period
|
||||||
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
||||||
debug.info(2,"Failed simulation: period {0} load {1} slew {2}, delay_hl={3}n delay_lh={4}ns slew_hl={5}n slew_lh={6}n".format(self.period,
|
debug.info(2,"Failed simulation: period {0} load {1} slew {2}, delay_hl={3}n delay_lh={4}ns slew_hl={5}n slew_lh={6}n".format(self.period,
|
||||||
|
|
@ -457,7 +459,7 @@ class delay():
|
||||||
else:
|
else:
|
||||||
lb_period = target_period
|
lb_period = target_period
|
||||||
|
|
||||||
if ch.relative_compare(ub_period, lb_period, error_tolerance=0.05):
|
if relative_compare(ub_period, lb_period, error_tolerance=0.05):
|
||||||
# ub_period is always feasible
|
# ub_period is always feasible
|
||||||
return ub_period
|
return ub_period
|
||||||
|
|
||||||
|
|
@ -471,10 +473,10 @@ class delay():
|
||||||
# Checking from not data_value to data_value
|
# Checking from not data_value to data_value
|
||||||
self.write_delay_stimulus()
|
self.write_delay_stimulus()
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
delay_hl = ch.parse_output("timing", "delay_hl")
|
delay_hl = parse_output("timing", "delay_hl")
|
||||||
delay_lh = ch.parse_output("timing", "delay_lh")
|
delay_lh = parse_output("timing", "delay_lh")
|
||||||
slew_hl = ch.parse_output("timing", "slew_hl")
|
slew_hl = parse_output("timing", "slew_hl")
|
||||||
slew_lh = ch.parse_output("timing", "slew_lh")
|
slew_lh = parse_output("timing", "slew_lh")
|
||||||
# if it failed or the read was longer than a period
|
# if it failed or the read was longer than a period
|
||||||
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
|
||||||
debug.info(2,"Invalid measures: Period {0}, delay_hl={1}ns, delay_lh={2}ns slew_hl={3}ns slew_lh={4}ns".format(self.period,
|
debug.info(2,"Invalid measures: Period {0}, delay_hl={1}ns, delay_lh={2}ns slew_hl={3}ns slew_lh={4}ns".format(self.period,
|
||||||
|
|
@ -495,10 +497,10 @@ class delay():
|
||||||
slew_lh))
|
slew_lh))
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
if not ch.relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
|
if not relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
|
||||||
debug.info(2,"Delay too big {0} vs {1}".format(delay_lh,feasible_delay_lh))
|
debug.info(2,"Delay too big {0} vs {1}".format(delay_lh,feasible_delay_lh))
|
||||||
return False
|
return False
|
||||||
elif not ch.relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
|
elif not relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
|
||||||
debug.info(2,"Delay too big {0} vs {1}".format(delay_hl,feasible_delay_hl))
|
debug.info(2,"Delay too big {0} vs {1}".format(delay_hl,feasible_delay_hl))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -602,7 +604,7 @@ class delay():
|
||||||
debug.info(1, "Min Period: {0}n with a delay of {1} / {2}".format(min_period, feasible_delay_lh, feasible_delay_hl))
|
debug.info(1, "Min Period: {0}n with a delay of {1} / {2}".format(min_period, feasible_delay_lh, feasible_delay_hl))
|
||||||
|
|
||||||
# 4) Pack up the final measurements
|
# 4) Pack up the final measurements
|
||||||
char_data["min_period"] = ch.round_time(min_period)
|
char_data["min_period"] = round_time(min_period)
|
||||||
|
|
||||||
return char_data
|
return char_data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import os,sys,re
|
import os,sys,re
|
||||||
import debug
|
import debug
|
||||||
import math
|
import math
|
||||||
import setup_hold
|
from .setup_hold import *
|
||||||
import delay
|
from .delay import *
|
||||||
import charutils as ch
|
from .charutils import *
|
||||||
import tech
|
import tech
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from globals import OPTS
|
from globals import OPTS
|
||||||
|
|
@ -186,9 +186,9 @@ class lib:
|
||||||
""" Helper function to create quoted, line wrapped array with each row of given length """
|
""" Helper function to create quoted, line wrapped array with each row of given length """
|
||||||
# check that the length is a multiple or give an error!
|
# check that the length is a multiple or give an error!
|
||||||
debug.check(len(values)%length == 0,"Values are not a multiple of the length. Cannot make a full array.")
|
debug.check(len(values)%length == 0,"Values are not a multiple of the length. Cannot make a full array.")
|
||||||
rounded_values = map(ch.round_time,values)
|
rounded_values = list(map(round_time,values))
|
||||||
split_values = [rounded_values[i:i+length] for i in range(0, len(rounded_values), length)]
|
split_values = [rounded_values[i:i+length] for i in range(0, len(rounded_values), length)]
|
||||||
formatted_rows = map(self.create_list,split_values)
|
formatted_rows = list(map(self.create_list,split_values))
|
||||||
formatted_array = ",\\\n".join(formatted_rows)
|
formatted_array = ",\\\n".join(formatted_rows)
|
||||||
return formatted_array
|
return formatted_array
|
||||||
|
|
||||||
|
|
@ -274,11 +274,11 @@ class lib:
|
||||||
self.lib.write(" timing_type : setup_rising; \n")
|
self.lib.write(" timing_type : setup_rising; \n")
|
||||||
self.lib.write(" related_pin : \"clk\"; \n")
|
self.lib.write(" related_pin : \"clk\"; \n")
|
||||||
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
||||||
rounded_values = map(ch.round_time,self.times["setup_times_LH"])
|
rounded_values = list(map(round_time,self.times["setup_times_LH"]))
|
||||||
self.write_values(rounded_values,len(self.slews)," ")
|
self.write_values(rounded_values,len(self.slews)," ")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
||||||
rounded_values = map(ch.round_time,self.times["setup_times_HL"])
|
rounded_values = list(map(round_time,self.times["setup_times_HL"]))
|
||||||
self.write_values(rounded_values,len(self.slews)," ")
|
self.write_values(rounded_values,len(self.slews)," ")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
|
|
@ -286,11 +286,11 @@ class lib:
|
||||||
self.lib.write(" timing_type : hold_rising; \n")
|
self.lib.write(" timing_type : hold_rising; \n")
|
||||||
self.lib.write(" related_pin : \"clk\"; \n")
|
self.lib.write(" related_pin : \"clk\"; \n")
|
||||||
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
|
||||||
rounded_values = map(ch.round_time,self.times["hold_times_LH"])
|
rounded_values = list(map(round_time,self.times["hold_times_LH"]))
|
||||||
self.write_values(rounded_values,len(self.slews)," ")
|
self.write_values(rounded_values,len(self.slews)," ")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
|
||||||
rounded_values = map(ch.round_time,self.times["hold_times_HL"])
|
rounded_values = list(map(round_time,self.times["hold_times_HL"]))
|
||||||
self.write_values(rounded_values,len(self.slews)," ")
|
self.write_values(rounded_values,len(self.slews)," ")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
|
|
@ -413,8 +413,8 @@ class lib:
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
self.lib.write(" }\n")
|
self.lib.write(" }\n")
|
||||||
|
|
||||||
min_pulse_width = ch.round_time(self.char_results["min_period"])/2.0
|
min_pulse_width = round_time(self.char_results["min_period"])/2.0
|
||||||
min_period = ch.round_time(self.char_results["min_period"])
|
min_period = round_time(self.char_results["min_period"])
|
||||||
self.lib.write(" timing(){ \n")
|
self.lib.write(" timing(){ \n")
|
||||||
self.lib.write(" timing_type :\"min_pulse_width\"; \n")
|
self.lib.write(" timing_type :\"min_pulse_width\"; \n")
|
||||||
self.lib.write(" related_pin : clk; \n")
|
self.lib.write(" related_pin : clk; \n")
|
||||||
|
|
@ -443,7 +443,7 @@ class lib:
|
||||||
try:
|
try:
|
||||||
self.d
|
self.d
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.d = delay.delay(self.sram, self.sp_file, self.corner)
|
self.d = delay(self.sram, self.sp_file, self.corner)
|
||||||
if self.use_model:
|
if self.use_model:
|
||||||
self.char_results = self.d.analytical_delay(self.sram,self.slews,self.loads)
|
self.char_results = self.d.analytical_delay(self.sram,self.slews,self.loads)
|
||||||
else:
|
else:
|
||||||
|
|
@ -458,7 +458,7 @@ class lib:
|
||||||
try:
|
try:
|
||||||
self.sh
|
self.sh
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.sh = setup_hold.setup_hold(self.corner)
|
self.sh = setup_hold(self.corner)
|
||||||
if self.use_model:
|
if self.use_model:
|
||||||
self.times = self.sh.analytical_setuphold(self.slews,self.loads)
|
self.times = self.sh.analytical_setuphold(self.slews,self.loads)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import sys
|
import sys
|
||||||
import tech
|
import tech
|
||||||
import stimuli
|
from .stimuli import *
|
||||||
import debug
|
import debug
|
||||||
import charutils as ch
|
from .charutils import *
|
||||||
import ms_flop
|
import ms_flop
|
||||||
from globals import OPTS
|
from globals import OPTS
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ class setup_hold():
|
||||||
# creates and opens the stimulus file for writing
|
# creates and opens the stimulus file for writing
|
||||||
temp_stim = OPTS.openram_temp + "stim.sp"
|
temp_stim = OPTS.openram_temp + "stim.sp"
|
||||||
self.sf = open(temp_stim, "w")
|
self.sf = open(temp_stim, "w")
|
||||||
self.stim = stimuli.stimuli(self.sf, self.corner)
|
self.stim = stimuli(self.sf, self.corner)
|
||||||
|
|
||||||
self.write_header(correct_value)
|
self.write_header(correct_value)
|
||||||
|
|
||||||
|
|
@ -186,8 +186,8 @@ class setup_hold():
|
||||||
target_time=feasible_bound,
|
target_time=feasible_bound,
|
||||||
correct_value=correct_value)
|
correct_value=correct_value)
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
ideal_clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
|
ideal_clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
|
||||||
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
|
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
|
||||||
debug.info(2,"*** {0} CHECK: {1} Ideal Clk-to-Q: {2} Setup/Hold: {3}".format(mode, correct_value,ideal_clk_to_q,setuphold_time))
|
debug.info(2,"*** {0} CHECK: {1} Ideal Clk-to-Q: {2} Setup/Hold: {3}".format(mode, correct_value,ideal_clk_to_q,setuphold_time))
|
||||||
|
|
||||||
if type(ideal_clk_to_q)!=float or type(setuphold_time)!=float:
|
if type(ideal_clk_to_q)!=float or type(setuphold_time)!=float:
|
||||||
|
|
@ -219,8 +219,8 @@ class setup_hold():
|
||||||
|
|
||||||
|
|
||||||
self.stim.run_sim()
|
self.stim.run_sim()
|
||||||
clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
|
clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
|
||||||
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
|
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
|
||||||
if type(clk_to_q)==float and (clk_to_q<1.1*ideal_clk_to_q) and type(setuphold_time)==float:
|
if type(clk_to_q)==float and (clk_to_q<1.1*ideal_clk_to_q) and type(setuphold_time)==float:
|
||||||
if mode == "SETUP": # SETUP is clk-din, not din-clk
|
if mode == "SETUP": # SETUP is clk-din, not din-clk
|
||||||
setuphold_time *= -1e9
|
setuphold_time *= -1e9
|
||||||
|
|
@ -235,7 +235,7 @@ class setup_hold():
|
||||||
infeasible_bound = target_time
|
infeasible_bound = target_time
|
||||||
|
|
||||||
#raw_input("Press Enter to continue...")
|
#raw_input("Press Enter to continue...")
|
||||||
if ch.relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
|
if relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
|
||||||
debug.info(3,"CONVERGE {0} vs {1}".format(feasible_bound,infeasible_bound))
|
debug.info(3,"CONVERGE {0} vs {1}".format(feasible_bound,infeasible_bound))
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@ Python GDS Mill Package
|
||||||
GDS Mill is a Python package for the creation and manipulation of binary GDS2 layout files.
|
GDS Mill is a Python package for the creation and manipulation of binary GDS2 layout files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from gds2reader import *
|
from .gds2reader import *
|
||||||
from gds2writer import *
|
from .gds2writer import *
|
||||||
from pdfLayout import *
|
#from .pdfLayout import *
|
||||||
from vlsiLayout import *
|
from .vlsiLayout import *
|
||||||
from gdsStreamer import *
|
from .gdsStreamer import *
|
||||||
from gdsPrimitives import *
|
from .gdsPrimitives import *
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import struct
|
import struct
|
||||||
from gdsPrimitives import *
|
from .gdsPrimitives import *
|
||||||
|
|
||||||
class Gds2writer:
|
class Gds2writer:
|
||||||
"""Class to take a populated layout class and write it to a file in GDSII format"""
|
"""Class to take a populated layout class and write it to a file in GDSII format"""
|
||||||
|
|
@ -14,8 +14,8 @@ class Gds2writer:
|
||||||
def print64AsBinary(self,number):
|
def print64AsBinary(self,number):
|
||||||
#debugging method for binary inspection
|
#debugging method for binary inspection
|
||||||
for index in range(0,64):
|
for index in range(0,64):
|
||||||
print (number>>(63-index))&0x1,
|
print((number>>(63-index))&0x1,eol='')
|
||||||
print "\n"
|
print("\n")
|
||||||
|
|
||||||
def ieeeDoubleFromIbmData(self,ibmData):
|
def ieeeDoubleFromIbmData(self,ibmData):
|
||||||
#the GDS double is in IBM 370 format like this:
|
#the GDS double is in IBM 370 format like this:
|
||||||
|
|
@ -40,9 +40,9 @@ class Gds2writer:
|
||||||
exponent-=1
|
exponent-=1
|
||||||
#check for underflow error -- should handle these properly!
|
#check for underflow error -- should handle these properly!
|
||||||
if(exponent<=0):
|
if(exponent<=0):
|
||||||
print "Underflow Error"
|
print("Underflow Error")
|
||||||
elif(exponent == 2047):
|
elif(exponent == 2047):
|
||||||
print "Overflow Error"
|
print("Overflow Error")
|
||||||
#re assemble
|
#re assemble
|
||||||
newFloat=(sign<<63)|(exponent<<52)|((mantissa>>12)&0xfffffffffffff)
|
newFloat=(sign<<63)|(exponent<<52)|((mantissa>>12)&0xfffffffffffff)
|
||||||
asciiDouble = struct.pack('>q',newFloat)
|
asciiDouble = struct.pack('>q',newFloat)
|
||||||
|
|
@ -84,12 +84,12 @@ class Gds2writer:
|
||||||
data = struct.unpack('>q',asciiDouble)[0]
|
data = struct.unpack('>q',asciiDouble)[0]
|
||||||
sign = data >> 63
|
sign = data >> 63
|
||||||
exponent = ((data >> 52) & 0x7ff)-1023
|
exponent = ((data >> 52) & 0x7ff)-1023
|
||||||
print exponent+1023
|
print(exponent+1023)
|
||||||
mantissa = data << 12 #chop off sign and exponent
|
mantissa = data << 12 #chop off sign and exponent
|
||||||
#self.print64AsBinary((sign<<63)|((exponent+1023)<<52)|(mantissa>>12))
|
#self.print64AsBinary((sign<<63)|((exponent+1023)<<52)|(mantissa>>12))
|
||||||
asciiDouble = struct.pack('>q',(sign<<63)|(exponent+1023<<52)|(mantissa>>12))
|
asciiDouble = struct.pack('>q',(sign<<63)|(exponent+1023<<52)|(mantissa>>12))
|
||||||
newFloat = struct.unpack('>d',asciiDouble)[0]
|
newFloat = struct.unpack('>d',asciiDouble)[0]
|
||||||
print "Check:"+str(newFloat)
|
print("Check:"+str(newFloat))
|
||||||
|
|
||||||
def writeRecord(self,record):
|
def writeRecord(self,record):
|
||||||
recordLength = len(record)+2 #make sure to include this in the length
|
recordLength = len(record)+2 #make sure to include this in the length
|
||||||
|
|
@ -99,12 +99,12 @@ class Gds2writer:
|
||||||
def writeHeader(self):
|
def writeHeader(self):
|
||||||
## Header
|
## Header
|
||||||
if("gdsVersion" in self.layoutObject.info):
|
if("gdsVersion" in self.layoutObject.info):
|
||||||
idBits='\x00\x02'
|
idBits=b'\x00\x02'
|
||||||
gdsVersion = struct.pack(">h",self.layoutObject.info["gdsVersion"])
|
gdsVersion = struct.pack(">h",self.layoutObject.info["gdsVersion"])
|
||||||
self.writeRecord(idBits+gdsVersion)
|
self.writeRecord(idBits+gdsVersion)
|
||||||
## Modified Date
|
## Modified Date
|
||||||
if("dates" in self.layoutObject.info):
|
if("dates" in self.layoutObject.info):
|
||||||
idBits='\x01\x02'
|
idBits=b'\x01\x02'
|
||||||
modYear = struct.pack(">h",self.layoutObject.info["dates"][0])
|
modYear = struct.pack(">h",self.layoutObject.info["dates"][0])
|
||||||
modMonth = struct.pack(">h",self.layoutObject.info["dates"][1])
|
modMonth = struct.pack(">h",self.layoutObject.info["dates"][1])
|
||||||
modDay = struct.pack(">h",self.layoutObject.info["dates"][2])
|
modDay = struct.pack(">h",self.layoutObject.info["dates"][2])
|
||||||
|
|
@ -122,43 +122,43 @@ class Gds2writer:
|
||||||
lastAccessMinute+lastAccessSecond)
|
lastAccessMinute+lastAccessSecond)
|
||||||
## LibraryName
|
## LibraryName
|
||||||
if("libraryName" in self.layoutObject.info):
|
if("libraryName" in self.layoutObject.info):
|
||||||
idBits='\x02\x06'
|
idBits=b'\x02\x06'
|
||||||
if (len(self.layoutObject.info["libraryName"]) % 2 != 0):
|
if (len(self.layoutObject.info["libraryName"]) % 2 != 0):
|
||||||
libraryName = self.layoutObject.info["libraryName"] + "\0"
|
libraryName = self.layoutObject.info["libraryName"].encode() + "\0"
|
||||||
else:
|
else:
|
||||||
libraryName = self.layoutObject.info["libraryName"]
|
libraryName = self.layoutObject.info["libraryName"].encode()
|
||||||
self.writeRecord(idBits+libraryName)
|
self.writeRecord(idBits+libraryName)
|
||||||
## reference libraries
|
## reference libraries
|
||||||
if("referenceLibraries" in self.layoutObject.info):
|
if("referenceLibraries" in self.layoutObject.info):
|
||||||
idBits='\x1F\x06'
|
idBits=b'\x1F\x06'
|
||||||
referenceLibraryA = self.layoutObject.info["referenceLibraries"][0]
|
referenceLibraryA = self.layoutObject.info["referenceLibraries"][0]
|
||||||
referenceLibraryB = self.layoutObject.info["referenceLibraries"][1]
|
referenceLibraryB = self.layoutObject.info["referenceLibraries"][1]
|
||||||
self.writeRecord(idBits+referenceLibraryA+referenceLibraryB)
|
self.writeRecord(idBits+referenceLibraryA+referenceLibraryB)
|
||||||
if("fonts" in self.layoutObject.info):
|
if("fonts" in self.layoutObject.info):
|
||||||
idBits='\x20\x06'
|
idBits=b'\x20\x06'
|
||||||
fontA = self.layoutObject.info["fonts"][0]
|
fontA = self.layoutObject.info["fonts"][0]
|
||||||
fontB = self.layoutObject.info["fonts"][1]
|
fontB = self.layoutObject.info["fonts"][1]
|
||||||
fontC = self.layoutObject.info["fonts"][2]
|
fontC = self.layoutObject.info["fonts"][2]
|
||||||
fontD = self.layoutObject.info["fonts"][3]
|
fontD = self.layoutObject.info["fonts"][3]
|
||||||
self.writeRecord(idBits+fontA+fontB+fontC+fontD)
|
self.writeRecord(idBits+fontA+fontB+fontC+fontD)
|
||||||
if("attributeTable" in self.layoutObject.info):
|
if("attributeTable" in self.layoutObject.info):
|
||||||
idBits='\x23\x06'
|
idBits=b'\x23\x06'
|
||||||
attributeTable = self.layoutObject.info["attributeTable"]
|
attributeTable = self.layoutObject.info["attributeTable"]
|
||||||
self.writeRecord(idBits+attributeTable)
|
self.writeRecord(idBits+attributeTable)
|
||||||
if("generations" in self.layoutObject.info):
|
if("generations" in self.layoutObject.info):
|
||||||
idBits='\x22\x02'
|
idBits=b'\x22\x02'
|
||||||
generations = struct.pack(">h",self.layoutObject.info["generations"])
|
generations = struct.pack(">h",self.layoutObject.info["generations"])
|
||||||
self.writeRecord(idBits+generations)
|
self.writeRecord(idBits+generations)
|
||||||
if("fileFormat" in self.layoutObject.info):
|
if("fileFormat" in self.layoutObject.info):
|
||||||
idBits='\x36\x02'
|
idBits=b'\x36\x02'
|
||||||
fileFormat = struct.pack(">h",self.layoutObject.info["fileFormat"])
|
fileFormat = struct.pack(">h",self.layoutObject.info["fileFormat"])
|
||||||
self.writeRecord(idBits+fileFormat)
|
self.writeRecord(idBits+fileFormat)
|
||||||
if("mask" in self.layoutObject.info):
|
if("mask" in self.layoutObject.info):
|
||||||
idBits='\x37\x06'
|
idBits=b'\x37\x06'
|
||||||
mask = self.layoutObject.info["mask"]
|
mask = self.layoutObject.info["mask"]
|
||||||
self.writeRecord(idBits+mask)
|
self.writeRecord(idBits+mask)
|
||||||
if("units" in self.layoutObject.info):
|
if("units" in self.layoutObject.info):
|
||||||
idBits='\x03\x05'
|
idBits=b'\x03\x05'
|
||||||
userUnits=self.ibmDataFromIeeeDouble(self.layoutObject.info["units"][0])
|
userUnits=self.ibmDataFromIeeeDouble(self.layoutObject.info["units"][0])
|
||||||
dbUnits=self.ibmDataFromIeeeDouble((self.layoutObject.info["units"][0]*1e-6/self.layoutObject.info["units"][1])*self.layoutObject.info["units"][1])
|
dbUnits=self.ibmDataFromIeeeDouble((self.layoutObject.info["units"][0]*1e-6/self.layoutObject.info["units"][1])*self.layoutObject.info["units"][1])
|
||||||
|
|
||||||
|
|
@ -176,171 +176,171 @@ class Gds2writer:
|
||||||
|
|
||||||
self.writeRecord(idBits+userUnits+dbUnits)
|
self.writeRecord(idBits+userUnits+dbUnits)
|
||||||
if(self.debugToTerminal==1):
|
if(self.debugToTerminal==1):
|
||||||
print "writer: userUnits %s"%(userUnits.encode("hex"))
|
print("writer: userUnits %s"%(userUnits.encode("hex")))
|
||||||
print "writer: dbUnits %s"%(dbUnits.encode("hex"))
|
print("writer: dbUnits %s"%(dbUnits.encode("hex")))
|
||||||
#self.ieeeFloatCheck(1.3e-6)
|
#self.ieeeFloatCheck(1.3e-6)
|
||||||
|
|
||||||
print "End of GDSII Header Written"
|
print("End of GDSII Header Written")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def writeBoundary(self,thisBoundary):
|
def writeBoundary(self,thisBoundary):
|
||||||
idBits = '\x08\x00' #record Type
|
idBits=b'\x08\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisBoundary.elementFlags!=""):
|
if(thisBoundary.elementFlags!=""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisBoundary.elementFlags)
|
elementFlags = struct.pack(">h",thisBoundary.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisBoundary.plex!=""):
|
if(thisBoundary.plex!=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisBoundary.plex)
|
plex = struct.pack(">i",thisBoundary.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisBoundary.drawingLayer!=""):
|
if(thisBoundary.drawingLayer!=""):
|
||||||
idBits='\x0D\x02' #drawig layer
|
idBits=b'\x0D\x02' #drawig layer
|
||||||
drawingLayer = struct.pack(">h",thisBoundary.drawingLayer)
|
drawingLayer = struct.pack(">h",thisBoundary.drawingLayer)
|
||||||
self.writeRecord(idBits+drawingLayer)
|
self.writeRecord(idBits+drawingLayer)
|
||||||
if(thisBoundary.purposeLayer):
|
if(thisBoundary.purposeLayer):
|
||||||
idBits='\x16\x02' #purpose layer
|
idBits=b'\x16\x02' #purpose layer
|
||||||
purposeLayer = struct.pack(">h",thisBoundary.purposeLayer)
|
purposeLayer = struct.pack(">h",thisBoundary.purposeLayer)
|
||||||
self.writeRecord(idBits+purposeLayer)
|
self.writeRecord(idBits+purposeLayer)
|
||||||
if(thisBoundary.dataType!=""):
|
if(thisBoundary.dataType!=""):
|
||||||
idBits='\x0E\x02'#DataType
|
idBits=b'\x0E\x02'#DataType
|
||||||
dataType = struct.pack(">h",thisBoundary.dataType)
|
dataType = struct.pack(">h",thisBoundary.dataType)
|
||||||
self.writeRecord(idBits+dataType)
|
self.writeRecord(idBits+dataType)
|
||||||
if(thisBoundary.coordinates!=""):
|
if(thisBoundary.coordinates!=""):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisBoundary.coordinates:
|
for coordinate in thisBoundary.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writePath(self,thisPath): #writes out a path structure
|
def writePath(self,thisPath): #writes out a path structure
|
||||||
idBits = '\x09\x00' #record Type
|
idBits=b'\x09\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisPath.elementFlags != ""):
|
if(thisPath.elementFlags != ""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisPath.elementFlags)
|
elementFlags = struct.pack(">h",thisPath.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisPath.plex!=""):
|
if(thisPath.plex!=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisPath.plex)
|
plex = struct.pack(">i",thisPath.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisPath.drawingLayer):
|
if(thisPath.drawingLayer):
|
||||||
idBits='\x0D\x02' #drawig layer
|
idBits=b'\x0D\x02' #drawig layer
|
||||||
drawingLayer = struct.pack(">h",thisPath.drawingLayer)
|
drawingLayer = struct.pack(">h",thisPath.drawingLayer)
|
||||||
self.writeRecord(idBits+drawingLayer)
|
self.writeRecord(idBits+drawingLayer)
|
||||||
if(thisPath.purposeLayer):
|
if(thisPath.purposeLayer):
|
||||||
idBits='\x16\x02' #purpose layer
|
idBits=b'\x16\x02' #purpose layer
|
||||||
purposeLayer = struct.pack(">h",thisPath.purposeLayer)
|
purposeLayer = struct.pack(">h",thisPath.purposeLayer)
|
||||||
self.writeRecord(idBits+purposeLayer)
|
self.writeRecord(idBits+purposeLayer)
|
||||||
if(thisPath.pathType):
|
if(thisPath.pathType):
|
||||||
idBits='\x21\x02' #Path type
|
idBits=b'\x21\x02' #Path type
|
||||||
pathType = struct.pack(">h",thisPath.pathType)
|
pathType = struct.pack(">h",thisPath.pathType)
|
||||||
self.writeRecord(idBits+pathType)
|
self.writeRecord(idBits+pathType)
|
||||||
if(thisPath.pathWidth):
|
if(thisPath.pathWidth):
|
||||||
idBits='\x0F\x03'
|
idBits=b'\x0F\x03'
|
||||||
pathWidth = struct.pack(">i",thisPath.pathWidth)
|
pathWidth = struct.pack(">i",thisPath.pathWidth)
|
||||||
self.writeRecord(idBits+pathWidth)
|
self.writeRecord(idBits+pathWidth)
|
||||||
if(thisPath.coordinates):
|
if(thisPath.coordinates):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisPath.coordinates:
|
for coordinate in thisPath.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeSref(self,thisSref): #reads in a reference to another structure
|
def writeSref(self,thisSref): #reads in a reference to another structure
|
||||||
idBits = '\x0A\x00' #record Type
|
idBits=b'\x0A\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisSref.elementFlags != ""):
|
if(thisSref.elementFlags != ""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisSref.elementFlags)
|
elementFlags = struct.pack(">h",thisSref.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisSref.plex!=""):
|
if(thisSref.plex!=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisSref.plex)
|
plex = struct.pack(">i",thisSref.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisSref.sName!=""):
|
if(thisSref.sName!=""):
|
||||||
idBits='\x12\x06'
|
idBits=b'\x12\x06'
|
||||||
if (len(thisSref.sName) % 2 != 0):
|
if (len(thisSref.sName) % 2 != 0):
|
||||||
sName = thisSref.sName+"\0"
|
sName = thisSref.sName+"\0"
|
||||||
else:
|
else:
|
||||||
sName = thisSref.sName
|
sName = thisSref.sName
|
||||||
self.writeRecord(idBits+sName)
|
self.writeRecord(idBits+sName.encode())
|
||||||
if(thisSref.transFlags!=""):
|
if(thisSref.transFlags!=""):
|
||||||
idBits='\x1A\x01'
|
idBits=b'\x1A\x01'
|
||||||
mirrorFlag = int(thisSref.transFlags[0])<<15
|
mirrorFlag = int(thisSref.transFlags[0])<<15
|
||||||
rotateFlag = int(thisSref.transFlags[1])<<1
|
rotateFlag = int(thisSref.transFlags[1])<<1
|
||||||
magnifyFlag = int(thisSref.transFlags[2])<<3
|
magnifyFlag = int(thisSref.transFlags[2])<<3
|
||||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||||
self.writeRecord(idBits+transFlags)
|
self.writeRecord(idBits+transFlags)
|
||||||
if(thisSref.magFactor!=""):
|
if(thisSref.magFactor!=""):
|
||||||
idBits='\x1B\x05'
|
idBits=b'\x1B\x05'
|
||||||
magFactor=self.ibmDataFromIeeeDouble(thisSref.magFactor)
|
magFactor=self.ibmDataFromIeeeDouble(thisSref.magFactor)
|
||||||
self.writeRecord(idBits+magFactor)
|
self.writeRecord(idBits+magFactor)
|
||||||
if(thisSref.rotateAngle!=""):
|
if(thisSref.rotateAngle!=""):
|
||||||
idBits='\x1C\x05'
|
idBits=b'\x1C\x05'
|
||||||
rotateAngle=self.ibmDataFromIeeeDouble(thisSref.rotateAngle)
|
rotateAngle=self.ibmDataFromIeeeDouble(thisSref.rotateAngle)
|
||||||
self.writeRecord(idBits+rotateAngle)
|
self.writeRecord(idBits+rotateAngle)
|
||||||
if(thisSref.coordinates!=""):
|
if(thisSref.coordinates!=""):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
coordinate = thisSref.coordinates
|
coordinate = thisSref.coordinates
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
#print thisSref.coordinates
|
#print(thisSref.coordinates)
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeAref(self,thisAref): #an array of references
|
def writeAref(self,thisAref): #an array of references
|
||||||
idBits = '\x0B\x00' #record Type
|
idBits=b'\x0B\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisAref.elementFlags!=""):
|
if(thisAref.elementFlags!=""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisAref.elementFlags)
|
elementFlags = struct.pack(">h",thisAref.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisAref.plex):
|
if(thisAref.plex):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisAref.plex)
|
plex = struct.pack(">i",thisAref.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisAref.aName):
|
if(thisAref.aName):
|
||||||
idBits='\x12\x06'
|
idBits=b'\x12\x06'
|
||||||
if (len(thisAref.aName) % 2 != 0):
|
if (len(thisAref.aName) % 2 != 0):
|
||||||
aName = thisAref.aName+"\0"
|
aName = thisAref.aName+"\0"
|
||||||
else:
|
else:
|
||||||
aName = thisAref.aName
|
aName = thisAref.aName
|
||||||
self.writeRecord(idBits+aName)
|
self.writeRecord(idBits+aName)
|
||||||
if(thisAref.transFlags):
|
if(thisAref.transFlags):
|
||||||
idBits='\x1A\x01'
|
idBits=b'\x1A\x01'
|
||||||
mirrorFlag = int(thisAref.transFlags[0])<<15
|
mirrorFlag = int(thisAref.transFlags[0])<<15
|
||||||
rotateFlag = int(thisAref.transFlags[1])<<1
|
rotateFlag = int(thisAref.transFlags[1])<<1
|
||||||
magnifyFlag = int(thisAref.transFlags[0])<<3
|
magnifyFlag = int(thisAref.transFlags[0])<<3
|
||||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||||
self.writeRecord(idBits+transFlags)
|
self.writeRecord(idBits+transFlags)
|
||||||
if(thisAref.magFactor):
|
if(thisAref.magFactor):
|
||||||
idBits='\x1B\x05'
|
idBits=b'\x1B\x05'
|
||||||
magFactor=self.ibmDataFromIeeeDouble(thisAref.magFactor)
|
magFactor=self.ibmDataFromIeeeDouble(thisAref.magFactor)
|
||||||
self.writeRecord(idBits+magFactor)
|
self.writeRecord(idBits+magFactor)
|
||||||
if(thisAref.rotateAngle):
|
if(thisAref.rotateAngle):
|
||||||
idBits='\x1C\x05'
|
idBits=b'\x1C\x05'
|
||||||
rotateAngle=self.ibmDataFromIeeeDouble(thisAref.rotateAngle)
|
rotateAngle=self.ibmDataFromIeeeDouble(thisAref.rotateAngle)
|
||||||
self.writeRecord(idBits+rotateAngle)
|
self.writeRecord(idBits+rotateAngle)
|
||||||
if(thisAref.coordinates):
|
if(thisAref.coordinates):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisAref.coordinates:
|
for coordinate in thisAref.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",coordinate[0])
|
||||||
|
|
@ -348,151 +348,151 @@ class Gds2writer:
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeText(self,thisText):
|
def writeText(self,thisText):
|
||||||
idBits = '\x0C\x00' #record Type
|
idBits=b'\x0C\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisText.elementFlags!=""):
|
if(thisText.elementFlags!=""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisText.elementFlags)
|
elementFlags = struct.pack(">h",thisText.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisText.plex !=""):
|
if(thisText.plex !=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisText.plex)
|
plex = struct.pack(">i",thisText.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisText.drawingLayer != ""):
|
if(thisText.drawingLayer != ""):
|
||||||
idBits='\x0D\x02' #drawing layer
|
idBits=b'\x0D\x02' #drawing layer
|
||||||
drawingLayer = struct.pack(">h",thisText.drawingLayer)
|
drawingLayer = struct.pack(">h",thisText.drawingLayer)
|
||||||
self.writeRecord(idBits+drawingLayer)
|
self.writeRecord(idBits+drawingLayer)
|
||||||
#if(thisText.purposeLayer):
|
#if(thisText.purposeLayer):
|
||||||
idBits='\x16\x02' #purpose layer
|
idBits=b'\x16\x02' #purpose layer
|
||||||
purposeLayer = struct.pack(">h",thisText.purposeLayer)
|
purposeLayer = struct.pack(">h",thisText.purposeLayer)
|
||||||
self.writeRecord(idBits+purposeLayer)
|
self.writeRecord(idBits+purposeLayer)
|
||||||
if(thisText.transFlags != ""):
|
if(thisText.transFlags != ""):
|
||||||
idBits='\x1A\x01'
|
idBits=b'\x1A\x01'
|
||||||
mirrorFlag = int(thisText.transFlags[0])<<15
|
mirrorFlag = int(thisText.transFlags[0])<<15
|
||||||
rotateFlag = int(thisText.transFlags[1])<<1
|
rotateFlag = int(thisText.transFlags[1])<<1
|
||||||
magnifyFlag = int(thisText.transFlags[0])<<3
|
magnifyFlag = int(thisText.transFlags[0])<<3
|
||||||
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
|
||||||
self.writeRecord(idBits+transFlags)
|
self.writeRecord(idBits+transFlags)
|
||||||
if(thisText.magFactor != ""):
|
if(thisText.magFactor != ""):
|
||||||
idBits='\x1B\x05'
|
idBits=b'\x1B\x05'
|
||||||
magFactor=self.ibmDataFromIeeeDouble(thisText.magFactor)
|
magFactor=self.ibmDataFromIeeeDouble(thisText.magFactor)
|
||||||
self.writeRecord(idBits+magFactor)
|
self.writeRecord(idBits+magFactor)
|
||||||
if(thisText.rotateAngle != ""):
|
if(thisText.rotateAngle != ""):
|
||||||
idBits='\x1C\x05'
|
idBits=b'\x1C\x05'
|
||||||
rotateAngle=self.ibmDataFromIeeeDouble(thisText.rotateAngle)
|
rotateAngle=self.ibmDataFromIeeeDouble(thisText.rotateAngle)
|
||||||
self.writeRecord(idBits+rotateAngle)
|
self.writeRecord(idBits+rotateAngle)
|
||||||
if(thisText.pathType !=""):
|
if(thisText.pathType !=""):
|
||||||
idBits='\x21\x02' #Path type
|
idBits=b'\x21\x02' #Path type
|
||||||
pathType = struct.pack(">h",thisText.pathType)
|
pathType = struct.pack(">h",thisText.pathType)
|
||||||
self.writeRecord(idBits+pathType)
|
self.writeRecord(idBits+pathType)
|
||||||
if(thisText.pathWidth != ""):
|
if(thisText.pathWidth != ""):
|
||||||
idBits='\x0F\x03'
|
idBits=b'\x0F\x03'
|
||||||
pathWidth = struct.pack(">i",thisText.pathWidth)
|
pathWidth = struct.pack(">i",thisText.pathWidth)
|
||||||
self.writeRecord(idBits+pathWidth)
|
self.writeRecord(idBits+pathWidth)
|
||||||
if(thisText.presentationFlags!=""):
|
if(thisText.presentationFlags!=""):
|
||||||
idBits='\x1A\x01'
|
idBits=b'\x1A\x01'
|
||||||
font = thisText.presentationFlags[0]<<4
|
font = thisText.presentationFlags[0]<<4
|
||||||
verticalFlags = int(thisText.presentationFlags[1])<<2
|
verticalFlags = int(thisText.presentationFlags[1])<<2
|
||||||
horizontalFlags = int(thisText.presentationFlags[2])
|
horizontalFlags = int(thisText.presentationFlags[2])
|
||||||
presentationFlags = struct.pack(">H",font|verticalFlags|horizontalFlags)
|
presentationFlags = struct.pack(">H",font|verticalFlags|horizontalFlags)
|
||||||
self.writeRecord(idBits+transFlags)
|
self.writeRecord(idBits+transFlags)
|
||||||
if(thisText.coordinates!=""):
|
if(thisText.coordinates!=""):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisText.coordinates:
|
for coordinate in thisText.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
if(thisText.textString):
|
if(thisText.textString):
|
||||||
idBits='\x19\x06'
|
idBits=b'\x19\x06'
|
||||||
textString = thisText.textString
|
textString = thisText.textString
|
||||||
self.writeRecord(idBits+textString)
|
self.writeRecord(idBits+textString.encode())
|
||||||
|
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeNode(self,thisNode):
|
def writeNode(self,thisNode):
|
||||||
idBits = '\x15\x00' #record Type
|
idBits=b'\x15\x00' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisNode.elementFlags!=""):
|
if(thisNode.elementFlags!=""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisNode.elementFlags)
|
elementFlags = struct.pack(">h",thisNode.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisNode.plex!=""):
|
if(thisNode.plex!=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisNode.plex)
|
plex = struct.pack(">i",thisNode.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisNode.drawingLayer!=""):
|
if(thisNode.drawingLayer!=""):
|
||||||
idBits='\x0D\x02' #drawig layer
|
idBits=b'\x0D\x02' #drawig layer
|
||||||
drawingLayer = struct.pack(">h",thisNode.drawingLayer)
|
drawingLayer = struct.pack(">h",thisNode.drawingLayer)
|
||||||
self.writeRecord(idBits+drawingLayer)
|
self.writeRecord(idBits+drawingLayer)
|
||||||
if(thisNode.nodeType!=""):
|
if(thisNode.nodeType!=""):
|
||||||
idBits='\x2A\x02'
|
idBits=b'\x2A\x02'
|
||||||
nodeType = struct.pack(">h",thisNode.nodeType)
|
nodeType = struct.pack(">h",thisNode.nodeType)
|
||||||
self.writeRecord(idBits+nodeType)
|
self.writeRecord(idBits+nodeType)
|
||||||
if(thisText.coordinates!=""):
|
if(thisText.coordinates!=""):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisText.coordinates:
|
for coordinate in thisText.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeBox(self,thisBox):
|
def writeBox(self,thisBox):
|
||||||
idBits = '\x2E\x02' #record Type
|
idBits=b'\x2E\x02' #record Type
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
if(thisBox.elementFlags!=""):
|
if(thisBox.elementFlags!=""):
|
||||||
idBits='\x26\x01' #ELFLAGS
|
idBits=b'\x26\x01' #ELFLAGS
|
||||||
elementFlags = struct.pack(">h",thisBox.elementFlags)
|
elementFlags = struct.pack(">h",thisBox.elementFlags)
|
||||||
self.writeRecord(idBits+elementFlags)
|
self.writeRecord(idBits+elementFlags)
|
||||||
if(thisBox.plex!=""):
|
if(thisBox.plex!=""):
|
||||||
idBits='\x2F\x03' #PLEX
|
idBits=b'\x2F\x03' #PLEX
|
||||||
plex = struct.pack(">i",thisBox.plex)
|
plex = struct.pack(">i",thisBox.plex)
|
||||||
self.writeRecord(idBits+plex)
|
self.writeRecord(idBits+plex)
|
||||||
if(thisBox.drawingLayer!=""):
|
if(thisBox.drawingLayer!=""):
|
||||||
idBits='\x0D\x02' #drawig layer
|
idBits=b'\x0D\x02' #drawig layer
|
||||||
drawingLayer = struct.pack(">h",thisBox.drawingLayer)
|
drawingLayer = struct.pack(">h",thisBox.drawingLayer)
|
||||||
self.writeRecord(idBits+drawingLayer)
|
self.writeRecord(idBits+drawingLayer)
|
||||||
if(thisBox.purposeLayer):
|
if(thisBox.purposeLayer):
|
||||||
idBits='\x16\x02' #purpose layer
|
idBits=b'\x16\x02' #purpose layer
|
||||||
purposeLayer = struct.pack(">h",thisBox.purposeLayer)
|
purposeLayer = struct.pack(">h",thisBox.purposeLayer)
|
||||||
self.writeRecord(idBits+purposeLayer)
|
self.writeRecord(idBits+purposeLayer)
|
||||||
if(thisBox.boxValue!=""):
|
if(thisBox.boxValue!=""):
|
||||||
idBits='\x2D\x00'
|
idBits=b'\x2D\x00'
|
||||||
boxValue = struct.pack(">h",thisBox.boxValue)
|
boxValue = struct.pack(">h",thisBox.boxValue)
|
||||||
self.writeRecord(idBits+boxValue)
|
self.writeRecord(idBits+boxValue)
|
||||||
if(thisBox.coordinates!=""):
|
if(thisBox.coordinates!=""):
|
||||||
idBits='\x10\x03' #XY Data Points
|
idBits=b'\x10\x03' #XY Data Points
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
for coordinate in thisBox.coordinates:
|
for coordinate in thisBox.coordinates:
|
||||||
x=struct.pack(">i",coordinate[0])
|
x=struct.pack(">i",int(coordinate[0]))
|
||||||
y=struct.pack(">i",coordinate[1])
|
y=struct.pack(">i",int(coordinate[1]))
|
||||||
coordinateRecord+=x
|
coordinateRecord+=x
|
||||||
coordinateRecord+=y
|
coordinateRecord+=y
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
idBits='\x11\x00' #End Of Element
|
idBits=b'\x11\x00' #End Of Element
|
||||||
coordinateRecord = idBits
|
coordinateRecord = idBits
|
||||||
self.writeRecord(coordinateRecord)
|
self.writeRecord(coordinateRecord)
|
||||||
|
|
||||||
def writeNextStructure(self,structureName):
|
def writeNextStructure(self,structureName):
|
||||||
#first put in the structure head
|
#first put in the structure head
|
||||||
thisStructure = self.layoutObject.structures[structureName]
|
thisStructure = self.layoutObject.structures[structureName]
|
||||||
idBits='\x05\x02'
|
idBits=b'\x05\x02'
|
||||||
createYear = struct.pack(">h",thisStructure.createDate[0])
|
createYear = struct.pack(">h",thisStructure.createDate[0])
|
||||||
createMonth = struct.pack(">h",thisStructure.createDate[1])
|
createMonth = struct.pack(">h",thisStructure.createDate[1])
|
||||||
createDay = struct.pack(">h",thisStructure.createDate[2])
|
createDay = struct.pack(">h",thisStructure.createDate[2])
|
||||||
|
|
@ -508,12 +508,12 @@ class Gds2writer:
|
||||||
self.writeRecord(idBits+createYear+createMonth+createDay+createHour+createMinute+createSecond\
|
self.writeRecord(idBits+createYear+createMonth+createDay+createHour+createMinute+createSecond\
|
||||||
+modYear+modMonth+modDay+modHour+modMinute+modSecond)
|
+modYear+modMonth+modDay+modHour+modMinute+modSecond)
|
||||||
#now the structure name
|
#now the structure name
|
||||||
idBits='\x06\x06'
|
idBits=b'\x06\x06'
|
||||||
##caveat: the name needs to be an EVEN number of characters
|
##caveat: the name needs to be an EVEN number of characters
|
||||||
if(len(structureName)%2 == 1):
|
if(len(structureName)%2 == 1):
|
||||||
#pad with a zero
|
#pad with a zero
|
||||||
structureName = structureName + '\x00'
|
structureName = structureName + '\x00'
|
||||||
self.writeRecord(idBits+structureName)
|
self.writeRecord(idBits+structureName.encode())
|
||||||
#now go through all the structure elements and write them in
|
#now go through all the structure elements and write them in
|
||||||
|
|
||||||
for boundary in thisStructure.boundaries:
|
for boundary in thisStructure.boundaries:
|
||||||
|
|
@ -531,7 +531,7 @@ class Gds2writer:
|
||||||
for box in thisStructure.boxes:
|
for box in thisStructure.boxes:
|
||||||
self.writeBox(box)
|
self.writeBox(box)
|
||||||
#put in the structure tail
|
#put in the structure tail
|
||||||
idBits='\x07\x00'
|
idBits=b'\x07\x00'
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
|
|
||||||
def writeGds2(self):
|
def writeGds2(self):
|
||||||
|
|
@ -540,7 +540,7 @@ class Gds2writer:
|
||||||
for structureName in self.layoutObject.structures:
|
for structureName in self.layoutObject.structures:
|
||||||
self.writeNextStructure(structureName)
|
self.writeNextStructure(structureName)
|
||||||
#at the end, put in the END LIB record
|
#at the end, put in the END LIB record
|
||||||
idBits='\x04\x00'
|
idBits=b'\x04\x00'
|
||||||
self.writeRecord(idBits)
|
self.writeRecord(idBits)
|
||||||
|
|
||||||
def writeToFile(self,fileName):
|
def writeToFile(self,fileName):
|
||||||
|
|
|
||||||
|
|
@ -122,11 +122,11 @@ class GdsStreamer:
|
||||||
#stream the gds out from cadence
|
#stream the gds out from cadence
|
||||||
worker = os.popen("pipo strmout "+self.workingDirectory+"/partStreamOut.tmpl")
|
worker = os.popen("pipo strmout "+self.workingDirectory+"/partStreamOut.tmpl")
|
||||||
#dump the outputs to the screen line by line
|
#dump the outputs to the screen line by line
|
||||||
print "Streaming Out From Cadence......"
|
print("Streaming Out From Cadence......")
|
||||||
while 1:
|
while 1:
|
||||||
line = worker.readline()
|
line = worker.readline()
|
||||||
if not line: break #this means sim is finished so jump out
|
if not line: break #this means sim is finished so jump out
|
||||||
#else: print line #for debug only
|
#else: print(line) #for debug only
|
||||||
worker.close()
|
worker.close()
|
||||||
#now remove the template file
|
#now remove the template file
|
||||||
os.remove(self.workingDirectory+"/partStreamOut.tmpl")
|
os.remove(self.workingDirectory+"/partStreamOut.tmpl")
|
||||||
|
|
@ -142,13 +142,13 @@ class GdsStreamer:
|
||||||
#stream the gds out from cadence
|
#stream the gds out from cadence
|
||||||
worker = os.popen("pipo strmin "+self.workingDirectory+"/partStreamIn.tmpl")
|
worker = os.popen("pipo strmin "+self.workingDirectory+"/partStreamIn.tmpl")
|
||||||
#dump the outputs to the screen line by line
|
#dump the outputs to the screen line by line
|
||||||
print "Streaming In To Cadence......"
|
print("Streaming In To Cadence......")
|
||||||
while 1:
|
while 1:
|
||||||
line = worker.readline()
|
line = worker.readline()
|
||||||
if not line: break #this means sim is finished so jump out
|
if not line: break #this means sim is finished so jump out
|
||||||
#else: print line #for debug only
|
#else: print(line) #for debug only
|
||||||
worker.close()
|
worker.close()
|
||||||
#now remove the template file
|
#now remove the template file
|
||||||
os.remove(self.workingDirectory+"/partStreamIn.tmpl")
|
os.remove(self.workingDirectory+"/partStreamIn.tmpl")
|
||||||
#and go back to whever it was we started from
|
#and go back to whever it was we started from
|
||||||
os.chdir(currentPath)
|
os.chdir(currentPath)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import pyx
|
import pyx
|
||||||
import math
|
import math
|
||||||
import mpmath
|
from numpy import matrix
|
||||||
from gdsPrimitives import *
|
from gdsPrimitives import *
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
@ -39,12 +39,12 @@ class pdfLayout:
|
||||||
"""
|
"""
|
||||||
xyCoordinates = []
|
xyCoordinates = []
|
||||||
#setup a translation matrix
|
#setup a translation matrix
|
||||||
tMatrix = mpmath.matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
|
tMatrix = matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
|
||||||
#and a rotation matrix
|
#and a rotation matrix
|
||||||
rMatrix = mpmath.matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
|
rMatrix = matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
|
||||||
for coordinate in uvCoordinates:
|
for coordinate in uvCoordinates:
|
||||||
#grab the point in UV space
|
#grab the point in UV space
|
||||||
uvPoint = mpmath.matrix([coordinate[0],coordinate[1],1.0])
|
uvPoint = matrix([coordinate[0],coordinate[1],1.0])
|
||||||
#now rotate and translate it back to XY space
|
#now rotate and translate it back to XY space
|
||||||
xyPoint = rMatrix * uvPoint
|
xyPoint = rMatrix * uvPoint
|
||||||
xyPoint = tMatrix * xyPoint
|
xyPoint = tMatrix * xyPoint
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from gdsPrimitives import *
|
from .gdsPrimitives import *
|
||||||
from datetime import *
|
from datetime import *
|
||||||
import mpmath
|
#from mpmath import matrix
|
||||||
import gdsPrimitives
|
from numpy import matrix
|
||||||
|
#import gdsPrimitives
|
||||||
import debug
|
import debug
|
||||||
|
|
||||||
class VlsiLayout:
|
class VlsiLayout:
|
||||||
|
|
@ -10,7 +11,7 @@ class VlsiLayout:
|
||||||
def __init__(self, name=None, units=(0.001,1e-9), libraryName = "DEFAULT.DB", gdsVersion=5):
|
def __init__(self, name=None, units=(0.001,1e-9), libraryName = "DEFAULT.DB", gdsVersion=5):
|
||||||
#keep a list of all the structures in this layout
|
#keep a list of all the structures in this layout
|
||||||
self.units = units
|
self.units = units
|
||||||
#print units
|
#print(units)
|
||||||
modDate = datetime.now()
|
modDate = datetime.now()
|
||||||
self.structures=dict()
|
self.structures=dict()
|
||||||
self.layerNumbersInUse = []
|
self.layerNumbersInUse = []
|
||||||
|
|
@ -89,7 +90,7 @@ class VlsiLayout:
|
||||||
|
|
||||||
def newLayout(self,newName):
|
def newLayout(self,newName):
|
||||||
#if (newName == "" | newName == 0):
|
#if (newName == "" | newName == 0):
|
||||||
# print("ERROR: vlsiLayout.py:newLayout newName is null")
|
# print("ERROR: vlsiLayout.py:newLayout newName is null")
|
||||||
|
|
||||||
#make sure the newName is a multiple of 2 characters
|
#make sure the newName is a multiple of 2 characters
|
||||||
#if(len(newName)%2 == 1):
|
#if(len(newName)%2 == 1):
|
||||||
|
|
@ -134,13 +135,12 @@ class VlsiLayout:
|
||||||
self.populateCoordinateMap()
|
self.populateCoordinateMap()
|
||||||
|
|
||||||
def deduceHierarchy(self):
|
def deduceHierarchy(self):
|
||||||
#first, find the root of the tree.
|
""" First, find the root of the tree.
|
||||||
#go through and get the name of every structure.
|
Then go through and get the name of every structure.
|
||||||
#then, go through and find which structure is not
|
Then, go through and find which structure is not
|
||||||
#contained by any other structure. this is the root.
|
contained by any other structure. this is the root."""
|
||||||
structureNames=[]
|
structureNames=[]
|
||||||
for name in self.structures:
|
for name in self.structures:
|
||||||
#print "deduceHierarchy: structure.name[%s]",name //FIXME: Added By Tom G.
|
|
||||||
structureNames+=[name]
|
structureNames+=[name]
|
||||||
|
|
||||||
for name in self.structures:
|
for name in self.structures:
|
||||||
|
|
@ -148,7 +148,7 @@ class VlsiLayout:
|
||||||
for sref in self.structures[name].srefs: #go through each reference
|
for sref in self.structures[name].srefs: #go through each reference
|
||||||
if sref.sName in structureNames: #and compare to our list
|
if sref.sName in structureNames: #and compare to our list
|
||||||
structureNames.remove(sref.sName)
|
structureNames.remove(sref.sName)
|
||||||
|
|
||||||
self.rootStructureName = structureNames[0]
|
self.rootStructureName = structureNames[0]
|
||||||
|
|
||||||
def traverseTheHierarchy(self, startingStructureName=None, delegateFunction = None,
|
def traverseTheHierarchy(self, startingStructureName=None, delegateFunction = None,
|
||||||
|
|
@ -163,19 +163,20 @@ class VlsiLayout:
|
||||||
rotateAngle = 0
|
rotateAngle = 0
|
||||||
else:
|
else:
|
||||||
rotateAngle = math.radians(float(rotateAngle))
|
rotateAngle = math.radians(float(rotateAngle))
|
||||||
mRotate = mpmath.matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
|
mRotate = matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
|
||||||
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],[0.0,0.0,1.0],])
|
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],
|
||||||
|
[0.0,0.0,1.0]])
|
||||||
#set up the translation matrix
|
#set up the translation matrix
|
||||||
translateX = float(coordinates[0])
|
translateX = float(coordinates[0])
|
||||||
translateY = float(coordinates[1])
|
translateY = float(coordinates[1])
|
||||||
mTranslate = mpmath.matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
|
mTranslate = matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
|
||||||
#set up the scale matrix (handles mirror X)
|
#set up the scale matrix (handles mirror X)
|
||||||
scaleX = 1.0
|
scaleX = 1.0
|
||||||
if(transFlags[0]):
|
if(transFlags[0]):
|
||||||
scaleY = -1.0
|
scaleY = -1.0
|
||||||
else:
|
else:
|
||||||
scaleY = 1.0
|
scaleY = 1.0
|
||||||
mScale = mpmath.matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
|
mScale = matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
|
||||||
|
|
||||||
#we need to keep track of all transforms in the hierarchy
|
#we need to keep track of all transforms in the hierarchy
|
||||||
#when we add an element to the xy tree, we apply all transforms from the bottom up
|
#when we add an element to the xy tree, we apply all transforms from the bottom up
|
||||||
|
|
@ -197,7 +198,7 @@ class VlsiLayout:
|
||||||
transFlags = sref.transFlags,
|
transFlags = sref.transFlags,
|
||||||
coordinates = sref.coordinates)
|
coordinates = sref.coordinates)
|
||||||
# else:
|
# else:
|
||||||
# print "WARNING: via encountered, ignoring:", sref.sName
|
# print("WARNING: via encountered, ignoring:", sref.sName)
|
||||||
#MUST HANDLE AREFs HERE AS WELL
|
#MUST HANDLE AREFs HERE AS WELL
|
||||||
#when we return, drop the last transform from the transformPath
|
#when we return, drop the last transform from the transformPath
|
||||||
del transformPath[-1]
|
del transformPath[-1]
|
||||||
|
|
@ -210,10 +211,10 @@ class VlsiLayout:
|
||||||
|
|
||||||
def populateCoordinateMap(self):
|
def populateCoordinateMap(self):
|
||||||
def addToXyTree(startingStructureName = None,transformPath = None):
|
def addToXyTree(startingStructureName = None,transformPath = None):
|
||||||
#print"populateCoordinateMap"
|
#print("populateCoordinateMap")
|
||||||
uVector = mpmath.matrix([1.0,0.0,0.0]) #start with normal basis vectors
|
uVector = matrix([1.0,0.0,0.0]).transpose() #start with normal basis vectors
|
||||||
vVector = mpmath.matrix([0.0,1.0,0.0])
|
vVector = matrix([0.0,1.0,0.0]).transpose()
|
||||||
origin = mpmath.matrix([0.0,0.0,1.0]) #and an origin (Z component is 1.0 to indicate position instead of vector)
|
origin = matrix([0.0,0.0,1.0]).transpose() #and an origin (Z component is 1.0 to indicate position instead of vector)
|
||||||
#make a copy of all the transforms and reverse it
|
#make a copy of all the transforms and reverse it
|
||||||
reverseTransformPath = transformPath[:]
|
reverseTransformPath = transformPath[:]
|
||||||
if len(reverseTransformPath) > 1:
|
if len(reverseTransformPath) > 1:
|
||||||
|
|
@ -245,7 +246,7 @@ class VlsiLayout:
|
||||||
#userUnitsPerMicron = userUnit / 1e-6
|
#userUnitsPerMicron = userUnit / 1e-6
|
||||||
userUnitsPerMicron = userUnit / (userUnit)
|
userUnitsPerMicron = userUnit / (userUnit)
|
||||||
layoutUnitsPerMicron = userUnitsPerMicron / self.units[0]
|
layoutUnitsPerMicron = userUnitsPerMicron / self.units[0]
|
||||||
#print "userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron]
|
#print("userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron])
|
||||||
return round(microns*layoutUnitsPerMicron,0)
|
return round(microns*layoutUnitsPerMicron,0)
|
||||||
|
|
||||||
def changeRoot(self,newRoot, create=False):
|
def changeRoot(self,newRoot, create=False):
|
||||||
|
|
@ -259,7 +260,7 @@ class VlsiLayout:
|
||||||
# Determine if newRoot exists
|
# Determine if newRoot exists
|
||||||
# layoutToAdd (default) or nameOfLayout
|
# layoutToAdd (default) or nameOfLayout
|
||||||
if (newRoot == 0 | ((newRoot not in self.structures) & ~create)):
|
if (newRoot == 0 | ((newRoot not in self.structures) & ~create)):
|
||||||
print "ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot
|
print("ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot)
|
||||||
exit(1)
|
exit(1)
|
||||||
else:
|
else:
|
||||||
if ((newRoot not in self.structures) & create):
|
if ((newRoot not in self.structures) & create):
|
||||||
|
|
@ -308,13 +309,13 @@ class VlsiLayout:
|
||||||
self.layerNumbersInUse += [layerNumber]
|
self.layerNumbersInUse += [layerNumber]
|
||||||
#Also, check if the user units / microns is the same as this Layout
|
#Also, check if the user units / microns is the same as this Layout
|
||||||
#if (layoutToAdd.units != self.units):
|
#if (layoutToAdd.units != self.units):
|
||||||
#print "WARNING: VlsiLayout: Units from design to be added do not match target Layout"
|
#print("WARNING: VlsiLayout: Units from design to be added do not match target Layout")
|
||||||
|
|
||||||
# if debug: print "DEBUG: vlsilayout: Using %d layers"
|
# if debug: print("DEBUG: vlsilayout: Using %d layers")
|
||||||
|
|
||||||
# If we can't find the structure, error
|
# If we can't find the structure, error
|
||||||
#if StructureFound == False:
|
#if StructureFound == False:
|
||||||
#print "ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout)
|
#print("ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout))
|
||||||
#return #FIXME: remove!
|
#return #FIXME: remove!
|
||||||
#exit(1)
|
#exit(1)
|
||||||
|
|
||||||
|
|
@ -353,10 +354,10 @@ class VlsiLayout:
|
||||||
Method to add a box to a layout
|
Method to add a box to a layout
|
||||||
"""
|
"""
|
||||||
offsetInLayoutUnits = (self.userUnits(offsetInMicrons[0]),self.userUnits(offsetInMicrons[1]))
|
offsetInLayoutUnits = (self.userUnits(offsetInMicrons[0]),self.userUnits(offsetInMicrons[1]))
|
||||||
#print "addBox:offsetInLayoutUnits",offsetInLayoutUnits
|
#print("addBox:offsetInLayoutUnits",offsetInLayoutUnits)
|
||||||
widthInLayoutUnits = self.userUnits(width)
|
widthInLayoutUnits = self.userUnits(width)
|
||||||
heightInLayoutUnits = self.userUnits(height)
|
heightInLayoutUnits = self.userUnits(height)
|
||||||
#print "offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits
|
#print("offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits)
|
||||||
if not center:
|
if not center:
|
||||||
coordinates=[offsetInLayoutUnits,
|
coordinates=[offsetInLayoutUnits,
|
||||||
(offsetInLayoutUnits[0]+widthInLayoutUnits,offsetInLayoutUnits[1]),
|
(offsetInLayoutUnits[0]+widthInLayoutUnits,offsetInLayoutUnits[1]),
|
||||||
|
|
@ -522,7 +523,7 @@ class VlsiLayout:
|
||||||
heightInBlocks = int(coverageHeight/effectiveBlock)
|
heightInBlocks = int(coverageHeight/effectiveBlock)
|
||||||
passFailRecord = []
|
passFailRecord = []
|
||||||
|
|
||||||
print "Filling layer:",layerToFill
|
print("Filling layer:",layerToFill)
|
||||||
def isThisBlockOk(startingStructureName,coordinates,rotateAngle=None):
|
def isThisBlockOk(startingStructureName,coordinates,rotateAngle=None):
|
||||||
#go through every boundary and check
|
#go through every boundary and check
|
||||||
for boundary in self.structures[startingStructureName].boundaries:
|
for boundary in self.structures[startingStructureName].boundaries:
|
||||||
|
|
@ -568,7 +569,7 @@ class VlsiLayout:
|
||||||
#if its bad, this global tempPassFail will be false
|
#if its bad, this global tempPassFail will be false
|
||||||
#if true, we can add the block
|
#if true, we can add the block
|
||||||
passFailRecord+=[self.tempPassFail]
|
passFailRecord+=[self.tempPassFail]
|
||||||
print "Percent Complete:"+str(percentDone)
|
print("Percent Complete:"+str(percentDone))
|
||||||
|
|
||||||
|
|
||||||
passFailIndex=0
|
passFailIndex=0
|
||||||
|
|
@ -579,7 +580,7 @@ class VlsiLayout:
|
||||||
if passFailRecord[passFailIndex]:
|
if passFailRecord[passFailIndex]:
|
||||||
self.addBox(layerToFill, (blockX,blockY), width=blockSize, height=blockSize)
|
self.addBox(layerToFill, (blockX,blockY), width=blockSize, height=blockSize)
|
||||||
passFailIndex+=1
|
passFailIndex+=1
|
||||||
print "Done\n\n"
|
print("Done\n\n")
|
||||||
|
|
||||||
def getLayoutBorder(self,borderlayer):
|
def getLayoutBorder(self,borderlayer):
|
||||||
for boundary in self.structures[self.rootStructureName].boundaries:
|
for boundary in self.structures[self.rootStructureName].boundaries:
|
||||||
|
|
@ -591,7 +592,7 @@ class VlsiLayout:
|
||||||
cellSize=[right_top[0]-left_bottom[0],right_top[1]-left_bottom[1]]
|
cellSize=[right_top[0]-left_bottom[0],right_top[1]-left_bottom[1]]
|
||||||
cellSizeMicron=[cellSize[0]*self.units[0],cellSize[1]*self.units[0]]
|
cellSizeMicron=[cellSize[0]*self.units[0],cellSize[1]*self.units[0]]
|
||||||
if not(cellSizeMicron):
|
if not(cellSizeMicron):
|
||||||
print "Error: "+str(self.rootStructureName)+".cell_size information not found yet"
|
print("Error: "+str(self.rootStructureName)+".cell_size information not found yet")
|
||||||
return cellSizeMicron
|
return cellSizeMicron
|
||||||
|
|
||||||
def measureSize(self,startStructure):
|
def measureSize(self,startStructure):
|
||||||
|
|
@ -700,7 +701,7 @@ class VlsiLayout:
|
||||||
debug.warning("Did not find pin on layer {0} at coordinate {1}".format(layer, coordinate))
|
debug.warning("Did not find pin on layer {0} at coordinate {1}".format(layer, coordinate))
|
||||||
|
|
||||||
# sort the boundaries, return the max area pin boundary
|
# sort the boundaries, return the max area pin boundary
|
||||||
pin_boundaries.sort(cmpBoundaryAreas,reverse=True)
|
pin_boundaries.sort(key=boundaryArea,reverse=True)
|
||||||
pin_boundary=pin_boundaries[0]
|
pin_boundary=pin_boundaries[0]
|
||||||
|
|
||||||
# Convert to USER units
|
# Convert to USER units
|
||||||
|
|
@ -743,7 +744,8 @@ class VlsiLayout:
|
||||||
shape_list=[]
|
shape_list=[]
|
||||||
for label in label_list:
|
for label in label_list:
|
||||||
(label_coordinate,label_layer)=label
|
(label_coordinate,label_layer)=label
|
||||||
shape_list.append(self.getPinShapeByDBLocLayer(label_coordinate, label_layer))
|
shape = self.getPinShapeByDBLocLayer(label_coordinate, label_layer)
|
||||||
|
shape_list.append(shape)
|
||||||
return shape_list
|
return shape_list
|
||||||
|
|
||||||
def getAllPinShapesByLabel(self,label_name):
|
def getAllPinShapesByLabel(self,label_name):
|
||||||
|
|
@ -797,23 +799,23 @@ class VlsiLayout:
|
||||||
# Rectangle is [leftx, bottomy, rightx, topy].
|
# Rectangle is [leftx, bottomy, rightx, topy].
|
||||||
boundaryRect=[left_bottom[0],left_bottom[1],right_top[0],right_top[1]]
|
boundaryRect=[left_bottom[0],left_bottom[1],right_top[0],right_top[1]]
|
||||||
boundaryRect=self.transformRectangle(boundaryRect,structureuVector,structurevVector)
|
boundaryRect=self.transformRectangle(boundaryRect,structureuVector,structurevVector)
|
||||||
boundaryRect=[boundaryRect[0]+structureOrigin[0],boundaryRect[1]+structureOrigin[1],
|
boundaryRect=[boundaryRect[0]+structureOrigin[0].item(),boundaryRect[1]+structureOrigin[1].item(),
|
||||||
boundaryRect[2]+structureOrigin[0],boundaryRect[3]+structureOrigin[1]]
|
boundaryRect[2]+structureOrigin[0].item(),boundaryRect[3]+structureOrigin[1].item()]
|
||||||
|
|
||||||
if self.labelInRectangle(coordinates,boundaryRect):
|
if self.labelInRectangle(coordinates,boundaryRect):
|
||||||
boundaries.append(boundaryRect)
|
boundaries.append(boundaryRect)
|
||||||
|
|
||||||
return boundaries
|
return boundaries
|
||||||
|
|
||||||
def transformRectangle(self,orignalRectangle,uVector,vVector):
|
def transformRectangle(self,originalRectangle,uVector,vVector):
|
||||||
"""
|
"""
|
||||||
Transforms the four coordinates of a rectangle in space
|
Transforms the four coordinates of a rectangle in space
|
||||||
and recomputes the left, bottom, right, top values.
|
and recomputes the left, bottom, right, top values.
|
||||||
"""
|
"""
|
||||||
leftBottom=mpmath.matrix([orignalRectangle[0],orignalRectangle[1]])
|
leftBottom=[originalRectangle[0],originalRectangle[1]]
|
||||||
leftBottom=self.transformCoordinate(leftBottom,uVector,vVector)
|
leftBottom=self.transformCoordinate(leftBottom,uVector,vVector)
|
||||||
|
|
||||||
rightTop=mpmath.matrix([orignalRectangle[2],orignalRectangle[3]])
|
rightTop=[originalRectangle[2],originalRectangle[3]]
|
||||||
rightTop=self.transformCoordinate(rightTop,uVector,vVector)
|
rightTop=self.transformCoordinate(rightTop,uVector,vVector)
|
||||||
|
|
||||||
left=min(leftBottom[0],rightTop[0])
|
left=min(leftBottom[0],rightTop[0])
|
||||||
|
|
@ -821,14 +823,15 @@ class VlsiLayout:
|
||||||
right=max(leftBottom[0],rightTop[0])
|
right=max(leftBottom[0],rightTop[0])
|
||||||
top=max(leftBottom[1],rightTop[1])
|
top=max(leftBottom[1],rightTop[1])
|
||||||
|
|
||||||
return [left,bottom,right,top]
|
newRectangle = [left,bottom,right,top]
|
||||||
|
return newRectangle
|
||||||
|
|
||||||
def transformCoordinate(self,coordinate,uVector,vVector):
|
def transformCoordinate(self,coordinate,uVector,vVector):
|
||||||
"""
|
"""
|
||||||
Rotate a coordinate in space.
|
Rotate a coordinate in space.
|
||||||
"""
|
"""
|
||||||
x=coordinate[0]*uVector[0]+coordinate[1]*uVector[1]
|
x=coordinate[0]*uVector[0].item()+coordinate[1]*uVector[1].item()
|
||||||
y=coordinate[1]*vVector[1]+coordinate[0]*vVector[0]
|
y=coordinate[1]*vVector[1].item()+coordinate[0]*vVector[0].item()
|
||||||
transformCoordinate=[x,y]
|
transformCoordinate=[x,y]
|
||||||
|
|
||||||
return transformCoordinate
|
return transformCoordinate
|
||||||
|
|
@ -845,18 +848,12 @@ class VlsiLayout:
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def cmpBoundaryAreas(A,B):
|
def boundaryArea(A):
|
||||||
"""
|
"""
|
||||||
Compares two rectangles and return true if Area(A)>Area(B).
|
Returns boundary area for sorting.
|
||||||
"""
|
"""
|
||||||
area_A=(A[2]-A[0])*(A[3]-A[1])
|
area_A=(A[2]-A[0])*(A[3]-A[1])
|
||||||
area_B=(B[2]-B[0])*(B[3]-B[1])
|
return area_A
|
||||||
if area_A>area_B:
|
|
||||||
return 1
|
|
||||||
elif area_A==area_B:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,386 +0,0 @@
|
||||||
__version__ = '0.14'
|
|
||||||
|
|
||||||
from usertools import monitor, timing
|
|
||||||
|
|
||||||
from ctx_fp import FPContext
|
|
||||||
from ctx_mp import MPContext
|
|
||||||
|
|
||||||
fp = FPContext()
|
|
||||||
mp = MPContext()
|
|
||||||
|
|
||||||
fp._mp = mp
|
|
||||||
mp._mp = mp
|
|
||||||
mp._fp = fp
|
|
||||||
fp._fp = fp
|
|
||||||
|
|
||||||
# XXX: extremely bad pickle hack
|
|
||||||
import ctx_mp as _ctx_mp
|
|
||||||
_ctx_mp._mpf_module.mpf = mp.mpf
|
|
||||||
_ctx_mp._mpf_module.mpc = mp.mpc
|
|
||||||
|
|
||||||
make_mpf = mp.make_mpf
|
|
||||||
make_mpc = mp.make_mpc
|
|
||||||
|
|
||||||
extraprec = mp.extraprec
|
|
||||||
extradps = mp.extradps
|
|
||||||
workprec = mp.workprec
|
|
||||||
workdps = mp.workdps
|
|
||||||
|
|
||||||
mag = mp.mag
|
|
||||||
|
|
||||||
bernfrac = mp.bernfrac
|
|
||||||
|
|
||||||
jdn = mp.jdn
|
|
||||||
jsn = mp.jsn
|
|
||||||
jcn = mp.jcn
|
|
||||||
jtheta = mp.jtheta
|
|
||||||
calculate_nome = mp.calculate_nome
|
|
||||||
|
|
||||||
nint_distance = mp.nint_distance
|
|
||||||
|
|
||||||
plot = mp.plot
|
|
||||||
cplot = mp.cplot
|
|
||||||
splot = mp.splot
|
|
||||||
|
|
||||||
odefun = mp.odefun
|
|
||||||
|
|
||||||
jacobian = mp.jacobian
|
|
||||||
findroot = mp.findroot
|
|
||||||
multiplicity = mp.multiplicity
|
|
||||||
|
|
||||||
isinf = mp.isinf
|
|
||||||
isnan = mp.isnan
|
|
||||||
isint = mp.isint
|
|
||||||
almosteq = mp.almosteq
|
|
||||||
nan = mp.nan
|
|
||||||
rand = mp.rand
|
|
||||||
|
|
||||||
absmin = mp.absmin
|
|
||||||
absmax = mp.absmax
|
|
||||||
|
|
||||||
fraction = mp.fraction
|
|
||||||
|
|
||||||
linspace = mp.linspace
|
|
||||||
arange = mp.arange
|
|
||||||
|
|
||||||
mpmathify = convert = mp.convert
|
|
||||||
mpc = mp.mpc
|
|
||||||
mpi = mp.mpi
|
|
||||||
|
|
||||||
nstr = mp.nstr
|
|
||||||
nprint = mp.nprint
|
|
||||||
chop = mp.chop
|
|
||||||
|
|
||||||
fneg = mp.fneg
|
|
||||||
fadd = mp.fadd
|
|
||||||
fsub = mp.fsub
|
|
||||||
fmul = mp.fmul
|
|
||||||
fdiv = mp.fdiv
|
|
||||||
fprod = mp.fprod
|
|
||||||
|
|
||||||
quad = mp.quad
|
|
||||||
quadgl = mp.quadgl
|
|
||||||
quadts = mp.quadts
|
|
||||||
quadosc = mp.quadosc
|
|
||||||
|
|
||||||
pslq = mp.pslq
|
|
||||||
identify = mp.identify
|
|
||||||
findpoly = mp.findpoly
|
|
||||||
|
|
||||||
richardson = mp.richardson
|
|
||||||
shanks = mp.shanks
|
|
||||||
nsum = mp.nsum
|
|
||||||
nprod = mp.nprod
|
|
||||||
diff = mp.diff
|
|
||||||
diffs = mp.diffs
|
|
||||||
diffun = mp.diffun
|
|
||||||
differint = mp.differint
|
|
||||||
taylor = mp.taylor
|
|
||||||
pade = mp.pade
|
|
||||||
polyval = mp.polyval
|
|
||||||
polyroots = mp.polyroots
|
|
||||||
fourier = mp.fourier
|
|
||||||
fourierval = mp.fourierval
|
|
||||||
sumem = mp.sumem
|
|
||||||
chebyfit = mp.chebyfit
|
|
||||||
limit = mp.limit
|
|
||||||
|
|
||||||
matrix = mp.matrix
|
|
||||||
eye = mp.eye
|
|
||||||
diag = mp.diag
|
|
||||||
zeros = mp.zeros
|
|
||||||
ones = mp.ones
|
|
||||||
hilbert = mp.hilbert
|
|
||||||
randmatrix = mp.randmatrix
|
|
||||||
swap_row = mp.swap_row
|
|
||||||
extend = mp.extend
|
|
||||||
norm = mp.norm
|
|
||||||
mnorm = mp.mnorm
|
|
||||||
|
|
||||||
lu_solve = mp.lu_solve
|
|
||||||
lu = mp.lu
|
|
||||||
unitvector = mp.unitvector
|
|
||||||
inverse = mp.inverse
|
|
||||||
residual = mp.residual
|
|
||||||
qr_solve = mp.qr_solve
|
|
||||||
cholesky = mp.cholesky
|
|
||||||
cholesky_solve = mp.cholesky_solve
|
|
||||||
det = mp.det
|
|
||||||
cond = mp.cond
|
|
||||||
|
|
||||||
expm = mp.expm
|
|
||||||
sqrtm = mp.sqrtm
|
|
||||||
powm = mp.powm
|
|
||||||
logm = mp.logm
|
|
||||||
sinm = mp.sinm
|
|
||||||
cosm = mp.cosm
|
|
||||||
|
|
||||||
mpf = mp.mpf
|
|
||||||
j = mp.j
|
|
||||||
exp = mp.exp
|
|
||||||
expj = mp.expj
|
|
||||||
expjpi = mp.expjpi
|
|
||||||
ln = mp.ln
|
|
||||||
im = mp.im
|
|
||||||
re = mp.re
|
|
||||||
inf = mp.inf
|
|
||||||
ninf = mp.ninf
|
|
||||||
sign = mp.sign
|
|
||||||
|
|
||||||
eps = mp.eps
|
|
||||||
pi = mp.pi
|
|
||||||
ln2 = mp.ln2
|
|
||||||
ln10 = mp.ln10
|
|
||||||
phi = mp.phi
|
|
||||||
e = mp.e
|
|
||||||
euler = mp.euler
|
|
||||||
catalan = mp.catalan
|
|
||||||
khinchin = mp.khinchin
|
|
||||||
glaisher = mp.glaisher
|
|
||||||
apery = mp.apery
|
|
||||||
degree = mp.degree
|
|
||||||
twinprime = mp.twinprime
|
|
||||||
mertens = mp.mertens
|
|
||||||
|
|
||||||
ldexp = mp.ldexp
|
|
||||||
frexp = mp.frexp
|
|
||||||
|
|
||||||
fsum = mp.fsum
|
|
||||||
fdot = mp.fdot
|
|
||||||
|
|
||||||
sqrt = mp.sqrt
|
|
||||||
cbrt = mp.cbrt
|
|
||||||
exp = mp.exp
|
|
||||||
ln = mp.ln
|
|
||||||
log = mp.log
|
|
||||||
log10 = mp.log10
|
|
||||||
power = mp.power
|
|
||||||
cos = mp.cos
|
|
||||||
sin = mp.sin
|
|
||||||
tan = mp.tan
|
|
||||||
cosh = mp.cosh
|
|
||||||
sinh = mp.sinh
|
|
||||||
tanh = mp.tanh
|
|
||||||
acos = mp.acos
|
|
||||||
asin = mp.asin
|
|
||||||
atan = mp.atan
|
|
||||||
asinh = mp.asinh
|
|
||||||
acosh = mp.acosh
|
|
||||||
atanh = mp.atanh
|
|
||||||
sec = mp.sec
|
|
||||||
csc = mp.csc
|
|
||||||
cot = mp.cot
|
|
||||||
sech = mp.sech
|
|
||||||
csch = mp.csch
|
|
||||||
coth = mp.coth
|
|
||||||
asec = mp.asec
|
|
||||||
acsc = mp.acsc
|
|
||||||
acot = mp.acot
|
|
||||||
asech = mp.asech
|
|
||||||
acsch = mp.acsch
|
|
||||||
acoth = mp.acoth
|
|
||||||
cospi = mp.cospi
|
|
||||||
sinpi = mp.sinpi
|
|
||||||
sinc = mp.sinc
|
|
||||||
sincpi = mp.sincpi
|
|
||||||
fabs = mp.fabs
|
|
||||||
re = mp.re
|
|
||||||
im = mp.im
|
|
||||||
conj = mp.conj
|
|
||||||
floor = mp.floor
|
|
||||||
ceil = mp.ceil
|
|
||||||
root = mp.root
|
|
||||||
nthroot = mp.nthroot
|
|
||||||
hypot = mp.hypot
|
|
||||||
modf = mp.modf
|
|
||||||
ldexp = mp.ldexp
|
|
||||||
frexp = mp.frexp
|
|
||||||
sign = mp.sign
|
|
||||||
arg = mp.arg
|
|
||||||
phase = mp.phase
|
|
||||||
polar = mp.polar
|
|
||||||
rect = mp.rect
|
|
||||||
degrees = mp.degrees
|
|
||||||
radians = mp.radians
|
|
||||||
atan2 = mp.atan2
|
|
||||||
fib = mp.fib
|
|
||||||
fibonacci = mp.fibonacci
|
|
||||||
lambertw = mp.lambertw
|
|
||||||
zeta = mp.zeta
|
|
||||||
altzeta = mp.altzeta
|
|
||||||
gamma = mp.gamma
|
|
||||||
factorial = mp.factorial
|
|
||||||
fac = mp.fac
|
|
||||||
fac2 = mp.fac2
|
|
||||||
beta = mp.beta
|
|
||||||
betainc = mp.betainc
|
|
||||||
psi = mp.psi
|
|
||||||
#psi0 = mp.psi0
|
|
||||||
#psi1 = mp.psi1
|
|
||||||
#psi2 = mp.psi2
|
|
||||||
#psi3 = mp.psi3
|
|
||||||
polygamma = mp.polygamma
|
|
||||||
digamma = mp.digamma
|
|
||||||
#trigamma = mp.trigamma
|
|
||||||
#tetragamma = mp.tetragamma
|
|
||||||
#pentagamma = mp.pentagamma
|
|
||||||
harmonic = mp.harmonic
|
|
||||||
bernoulli = mp.bernoulli
|
|
||||||
bernfrac = mp.bernfrac
|
|
||||||
stieltjes = mp.stieltjes
|
|
||||||
hurwitz = mp.hurwitz
|
|
||||||
dirichlet = mp.dirichlet
|
|
||||||
bernpoly = mp.bernpoly
|
|
||||||
eulerpoly = mp.eulerpoly
|
|
||||||
eulernum = mp.eulernum
|
|
||||||
polylog = mp.polylog
|
|
||||||
clsin = mp.clsin
|
|
||||||
clcos = mp.clcos
|
|
||||||
gammainc = mp.gammainc
|
|
||||||
gammaprod = mp.gammaprod
|
|
||||||
binomial = mp.binomial
|
|
||||||
rf = mp.rf
|
|
||||||
ff = mp.ff
|
|
||||||
hyper = mp.hyper
|
|
||||||
hyp0f1 = mp.hyp0f1
|
|
||||||
hyp1f1 = mp.hyp1f1
|
|
||||||
hyp1f2 = mp.hyp1f2
|
|
||||||
hyp2f1 = mp.hyp2f1
|
|
||||||
hyp2f2 = mp.hyp2f2
|
|
||||||
hyp2f0 = mp.hyp2f0
|
|
||||||
hyp2f3 = mp.hyp2f3
|
|
||||||
hyp3f2 = mp.hyp3f2
|
|
||||||
hyperu = mp.hyperu
|
|
||||||
hypercomb = mp.hypercomb
|
|
||||||
meijerg = mp.meijerg
|
|
||||||
appellf1 = mp.appellf1
|
|
||||||
erf = mp.erf
|
|
||||||
erfc = mp.erfc
|
|
||||||
erfi = mp.erfi
|
|
||||||
erfinv = mp.erfinv
|
|
||||||
npdf = mp.npdf
|
|
||||||
ncdf = mp.ncdf
|
|
||||||
expint = mp.expint
|
|
||||||
e1 = mp.e1
|
|
||||||
ei = mp.ei
|
|
||||||
li = mp.li
|
|
||||||
ci = mp.ci
|
|
||||||
si = mp.si
|
|
||||||
chi = mp.chi
|
|
||||||
shi = mp.shi
|
|
||||||
fresnels = mp.fresnels
|
|
||||||
fresnelc = mp.fresnelc
|
|
||||||
airyai = mp.airyai
|
|
||||||
airybi = mp.airybi
|
|
||||||
ellipe = mp.ellipe
|
|
||||||
ellipk = mp.ellipk
|
|
||||||
agm = mp.agm
|
|
||||||
jacobi = mp.jacobi
|
|
||||||
chebyt = mp.chebyt
|
|
||||||
chebyu = mp.chebyu
|
|
||||||
legendre = mp.legendre
|
|
||||||
legenp = mp.legenp
|
|
||||||
legenq = mp.legenq
|
|
||||||
hermite = mp.hermite
|
|
||||||
gegenbauer = mp.gegenbauer
|
|
||||||
laguerre = mp.laguerre
|
|
||||||
spherharm = mp.spherharm
|
|
||||||
besselj = mp.besselj
|
|
||||||
j0 = mp.j0
|
|
||||||
j1 = mp.j1
|
|
||||||
besseli = mp.besseli
|
|
||||||
bessely = mp.bessely
|
|
||||||
besselk = mp.besselk
|
|
||||||
hankel1 = mp.hankel1
|
|
||||||
hankel2 = mp.hankel2
|
|
||||||
struveh = mp.struveh
|
|
||||||
struvel = mp.struvel
|
|
||||||
whitm = mp.whitm
|
|
||||||
whitw = mp.whitw
|
|
||||||
ber = mp.ber
|
|
||||||
bei = mp.bei
|
|
||||||
ker = mp.ker
|
|
||||||
kei = mp.kei
|
|
||||||
coulombc = mp.coulombc
|
|
||||||
coulombf = mp.coulombf
|
|
||||||
coulombg = mp.coulombg
|
|
||||||
lambertw = mp.lambertw
|
|
||||||
barnesg = mp.barnesg
|
|
||||||
superfac = mp.superfac
|
|
||||||
hyperfac = mp.hyperfac
|
|
||||||
loggamma = mp.loggamma
|
|
||||||
siegeltheta = mp.siegeltheta
|
|
||||||
siegelz = mp.siegelz
|
|
||||||
grampoint = mp.grampoint
|
|
||||||
zetazero = mp.zetazero
|
|
||||||
riemannr = mp.riemannr
|
|
||||||
primepi = mp.primepi
|
|
||||||
primepi2 = mp.primepi2
|
|
||||||
primezeta = mp.primezeta
|
|
||||||
bell = mp.bell
|
|
||||||
polyexp = mp.polyexp
|
|
||||||
expm1 = mp.expm1
|
|
||||||
powm1 = mp.powm1
|
|
||||||
unitroots = mp.unitroots
|
|
||||||
cyclotomic = mp.cyclotomic
|
|
||||||
|
|
||||||
|
|
||||||
# be careful when changing this name, don't use test*!
|
|
||||||
def runtests():
|
|
||||||
"""
|
|
||||||
Run all mpmath tests and print output.
|
|
||||||
"""
|
|
||||||
import os.path
|
|
||||||
from inspect import getsourcefile
|
|
||||||
import tests.runtests as tests
|
|
||||||
testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
|
|
||||||
importdir = os.path.abspath(testdir + '/../..')
|
|
||||||
tests.testit(importdir, testdir)
|
|
||||||
|
|
||||||
def doctests():
|
|
||||||
try:
|
|
||||||
import psyco; psyco.full()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
import sys
|
|
||||||
from timeit import default_timer as clock
|
|
||||||
filter = []
|
|
||||||
for i, arg in enumerate(sys.argv):
|
|
||||||
if '__init__.py' in arg:
|
|
||||||
filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
|
|
||||||
break
|
|
||||||
import doctest
|
|
||||||
globs = globals().copy()
|
|
||||||
for obj in globs: #sorted(globs.keys()):
|
|
||||||
if filter:
|
|
||||||
if not sum([pat in obj for pat in filter]):
|
|
||||||
continue
|
|
||||||
print obj,
|
|
||||||
t1 = clock()
|
|
||||||
doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
|
|
||||||
t2 = clock()
|
|
||||||
print round(t2-t1, 3)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
doctests()
|
|
||||||
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
import calculus
|
|
||||||
# XXX: hack to set methods
|
|
||||||
import approximation
|
|
||||||
import differentiation
|
|
||||||
import extrapolation
|
|
||||||
import polynomials
|
|
||||||
|
|
@ -1,246 +0,0 @@
|
||||||
from calculus import defun
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
# Approximation methods #
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
|
|
||||||
# The Chebyshev approximation formula is given at:
|
|
||||||
# http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
|
|
||||||
|
|
||||||
# The only major changes in the following code is that we return the
|
|
||||||
# expanded polynomial coefficients instead of Chebyshev coefficients,
|
|
||||||
# and that we automatically transform [a,b] -> [-1,1] and back
|
|
||||||
# for convenience.
|
|
||||||
|
|
||||||
# Coefficient in Chebyshev approximation
|
|
||||||
def chebcoeff(ctx,f,a,b,j,N):
|
|
||||||
s = ctx.mpf(0)
|
|
||||||
h = ctx.mpf(0.5)
|
|
||||||
for k in range(1, N+1):
|
|
||||||
t = ctx.cos(ctx.pi*(k-h)/N)
|
|
||||||
s += f(t*(b-a)*h + (b+a)*h) * ctx.cos(ctx.pi*j*(k-h)/N)
|
|
||||||
return 2*s/N
|
|
||||||
|
|
||||||
# Generate Chebyshev polynomials T_n(ax+b) in expanded form
|
|
||||||
def chebT(ctx, a=1, b=0):
|
|
||||||
Tb = [1]
|
|
||||||
yield Tb
|
|
||||||
Ta = [b, a]
|
|
||||||
while 1:
|
|
||||||
yield Ta
|
|
||||||
# Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
|
|
||||||
Tmp = [0] + [2*a*t for t in Ta]
|
|
||||||
for i, c in enumerate(Ta): Tmp[i] += 2*b*c
|
|
||||||
for i, c in enumerate(Tb): Tmp[i] -= c
|
|
||||||
Ta, Tb = Tmp, Ta
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def chebyfit(ctx, f, interval, N, error=False):
|
|
||||||
r"""
|
|
||||||
Computes a polynomial of degree `N-1` that approximates the
|
|
||||||
given function `f` on the interval `[a, b]`. With ``error=True``,
|
|
||||||
:func:`chebyfit` also returns an accurate estimate of the
|
|
||||||
maximum absolute error; that is, the maximum value of
|
|
||||||
`|f(x) - P(x)|` for `x \in [a, b]`.
|
|
||||||
|
|
||||||
:func:`chebyfit` uses the Chebyshev approximation formula,
|
|
||||||
which gives a nearly optimal solution: that is, the maximum
|
|
||||||
error of the approximating polynomial is very close to
|
|
||||||
the smallest possible for any polynomial of the same degree.
|
|
||||||
|
|
||||||
Chebyshev approximation is very useful if one needs repeated
|
|
||||||
evaluation of an expensive function, such as function defined
|
|
||||||
implicitly by an integral or a differential equation. (For
|
|
||||||
example, it could be used to turn a slow mpmath function
|
|
||||||
into a fast machine-precision version of the same.)
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Here we use :func:`chebyfit` to generate a low-degree approximation
|
|
||||||
of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
|
|
||||||
>>> nprint(poly)
|
|
||||||
[0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
|
|
||||||
>>> nprint(err, 12)
|
|
||||||
1.61351758081e-5
|
|
||||||
|
|
||||||
The polynomial can be evaluated using ``polyval``::
|
|
||||||
|
|
||||||
>>> nprint(polyval(poly, 1.6), 12)
|
|
||||||
-0.0291858904138
|
|
||||||
>>> nprint(cos(1.6), 12)
|
|
||||||
-0.0291995223013
|
|
||||||
|
|
||||||
Sampling the true error at 1000 points shows that the error
|
|
||||||
estimate generated by ``chebyfit`` is remarkably good::
|
|
||||||
|
|
||||||
>>> error = lambda x: abs(cos(x) - polyval(poly, x))
|
|
||||||
>>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
|
|
||||||
1.61349954245e-5
|
|
||||||
|
|
||||||
**Choice of degree**
|
|
||||||
|
|
||||||
The degree `N` can be set arbitrarily high, to obtain an
|
|
||||||
arbitrarily good approximation. As a rule of thumb, an
|
|
||||||
`N`-term Chebyshev approximation is good to `N/(b-a)` decimal
|
|
||||||
places on a unit interval (although this depends on how
|
|
||||||
well-behaved `f` is). The cost grows accordingly: ``chebyfit``
|
|
||||||
evaluates the function `(N^2)/2` times to compute the
|
|
||||||
coefficients and an additional `N` times to estimate the error.
|
|
||||||
|
|
||||||
**Possible issues**
|
|
||||||
|
|
||||||
One should be careful to use a sufficiently high working
|
|
||||||
precision both when calling ``chebyfit`` and when evaluating
|
|
||||||
the resulting polynomial, as the polynomial is sometimes
|
|
||||||
ill-conditioned. It is for example difficult to reach
|
|
||||||
15-digit accuracy when evaluating the polynomial using
|
|
||||||
machine precision floats, no matter the theoretical
|
|
||||||
accuracy of the polynomial. (The option to return the
|
|
||||||
coefficients in Chebyshev form should be made available
|
|
||||||
in the future.)
|
|
||||||
|
|
||||||
It is important to note the Chebyshev approximation works
|
|
||||||
poorly if `f` is not smooth. A function containing singularities,
|
|
||||||
rapid oscillation, etc can be approximated more effectively by
|
|
||||||
multiplying it by a weight function that cancels out the
|
|
||||||
nonsmooth features, or by dividing the interval into several
|
|
||||||
segments.
|
|
||||||
"""
|
|
||||||
a, b = ctx._as_points(interval)
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec = orig + int(N**0.5) + 20
|
|
||||||
c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
|
|
||||||
d = [ctx.zero] * N
|
|
||||||
d[0] = -c[0]/2
|
|
||||||
h = ctx.mpf(0.5)
|
|
||||||
T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
|
|
||||||
for k in range(N):
|
|
||||||
Tk = T.next()
|
|
||||||
for i in range(len(Tk)):
|
|
||||||
d[i] += c[k]*Tk[i]
|
|
||||||
d = d[::-1]
|
|
||||||
# Estimate maximum error
|
|
||||||
err = ctx.zero
|
|
||||||
for k in range(N):
|
|
||||||
x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
|
|
||||||
err = max(err, abs(f(x) - ctx.polyval(d, x)))
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
if error:
|
|
||||||
return d, +err
|
|
||||||
else:
|
|
||||||
return d
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def fourier(ctx, f, interval, N):
|
|
||||||
r"""
|
|
||||||
Computes the Fourier series of degree `N` of the given function
|
|
||||||
on the interval `[a, b]`. More precisely, :func:`fourier` returns
|
|
||||||
two lists `(c, s)` of coefficients (the cosine series and sine
|
|
||||||
series, respectively), such that
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
f(x) \sim \sum_{k=0}^N
|
|
||||||
c_k \cos(k m) + s_k \sin(k m)
|
|
||||||
|
|
||||||
where `m = 2 \pi / (b-a)`.
|
|
||||||
|
|
||||||
Note that many texts define the first coefficient as `2 c_0` instead
|
|
||||||
of `c_0`. The easiest way to evaluate the computed series correctly
|
|
||||||
is to pass it to :func:`fourierval`.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
The function `f(x) = x` has a simple Fourier series on the standard
|
|
||||||
interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
|
|
||||||
the function has odd symmetry), and the sine coefficients are
|
|
||||||
rational numbers::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> c, s = fourier(lambda x: x, [-pi, pi], 5)
|
|
||||||
>>> nprint(c)
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
|
||||||
>>> nprint(s)
|
|
||||||
[0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]
|
|
||||||
|
|
||||||
This computes a Fourier series of a nonsymmetric function on
|
|
||||||
a nonstandard interval::
|
|
||||||
|
|
||||||
>>> I = [-1, 1.5]
|
|
||||||
>>> f = lambda x: x**2 - 4*x + 1
|
|
||||||
>>> cs = fourier(f, I, 4)
|
|
||||||
>>> nprint(cs[0])
|
|
||||||
[0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
|
|
||||||
>>> nprint(cs[1])
|
|
||||||
[0.0, -2.6255, 0.580905, 0.219974, -0.540057]
|
|
||||||
|
|
||||||
It is instructive to plot a function along with its truncated
|
|
||||||
Fourier series::
|
|
||||||
|
|
||||||
>>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP
|
|
||||||
|
|
||||||
Fourier series generally converge slowly (and may not converge
|
|
||||||
pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
|
|
||||||
series gives an `L^2` error corresponding to 2-digit accuracy::
|
|
||||||
|
|
||||||
>>> I = [-1, 1]
|
|
||||||
>>> cs = fourier(cosh, I, 9)
|
|
||||||
>>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
|
|
||||||
>>> nprint(sqrt(quad(g, I)))
|
|
||||||
0.00467963
|
|
||||||
|
|
||||||
:func:`fourier` uses numerical quadrature. For nonsmooth functions,
|
|
||||||
the accuracy (and speed) can be improved by including all singular
|
|
||||||
points in the interval specification::
|
|
||||||
|
|
||||||
>>> nprint(fourier(abs, [-1, 1], 0), 10)
|
|
||||||
([0.5000441648], [0.0])
|
|
||||||
>>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
|
|
||||||
([0.5], [0.0])
|
|
||||||
|
|
||||||
"""
|
|
||||||
interval = ctx._as_points(interval)
|
|
||||||
a = interval[0]
|
|
||||||
b = interval[-1]
|
|
||||||
L = b-a
|
|
||||||
cos_series = []
|
|
||||||
sin_series = []
|
|
||||||
cutoff = ctx.eps*10
|
|
||||||
for n in xrange(N+1):
|
|
||||||
m = 2*n*ctx.pi/L
|
|
||||||
an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
|
|
||||||
bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
|
|
||||||
if n == 0:
|
|
||||||
an /= 2
|
|
||||||
if abs(an) < cutoff: an = ctx.zero
|
|
||||||
if abs(bn) < cutoff: bn = ctx.zero
|
|
||||||
cos_series.append(an)
|
|
||||||
sin_series.append(bn)
|
|
||||||
return cos_series, sin_series
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def fourierval(ctx, series, interval, x):
|
|
||||||
"""
|
|
||||||
Evaluates a Fourier series (in the format computed by
|
|
||||||
by :func:`fourier` for the given interval) at the point `x`.
|
|
||||||
|
|
||||||
The series should be a pair `(c, s)` where `c` is the
|
|
||||||
cosine series and `s` is the sine series. The two lists
|
|
||||||
need not have the same length.
|
|
||||||
"""
|
|
||||||
cs, ss = series
|
|
||||||
ab = ctx._as_points(interval)
|
|
||||||
a = interval[0]
|
|
||||||
b = interval[-1]
|
|
||||||
m = 2*ctx.pi/(ab[-1]-ab[0])
|
|
||||||
s = ctx.zero
|
|
||||||
s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
|
|
||||||
s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
|
|
||||||
return s
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
class CalculusMethods(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def defun(f):
|
|
||||||
setattr(CalculusMethods, f.__name__, f)
|
|
||||||
|
|
@ -1,438 +0,0 @@
|
||||||
from calculus import defun
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
# Differentiation #
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def difference_delta(ctx, s, n):
|
|
||||||
r"""
|
|
||||||
Given a sequence `(s_k)` containing at least `n+1` items, returns the
|
|
||||||
`n`-th forward difference,
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\Delta^n = \sum_{k=0}^{\infty} (-1)^{k+n} {n \choose k} s_k.
|
|
||||||
"""
|
|
||||||
n = int(n)
|
|
||||||
d = ctx.zero
|
|
||||||
b = (-1) ** (n & 1)
|
|
||||||
for k in xrange(n+1):
|
|
||||||
d += b * s[k]
|
|
||||||
b = (b * (k-n)) // (k+1)
|
|
||||||
return d
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def diff(ctx, f, x, n=1, method='step', scale=1, direction=0):
|
|
||||||
r"""
|
|
||||||
Numerically computes the derivative of `f`, `f'(x)`. Optionally,
|
|
||||||
computes the `n`-th derivative, `f^{(n)}(x)`, for any order `n`.
|
|
||||||
|
|
||||||
**Basic examples**
|
|
||||||
|
|
||||||
Derivatives of a simple function::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> diff(lambda x: x**2 + x, 1.0)
|
|
||||||
3.0
|
|
||||||
>>> diff(lambda x: x**2 + x, 1.0, 2)
|
|
||||||
2.0
|
|
||||||
>>> diff(lambda x: x**2 + x, 1.0, 3)
|
|
||||||
0.0
|
|
||||||
|
|
||||||
The exponential function is invariant under differentiation::
|
|
||||||
|
|
||||||
>>> nprint([diff(exp, 3, n) for n in range(5)])
|
|
||||||
[20.0855, 20.0855, 20.0855, 20.0855, 20.0855]
|
|
||||||
|
|
||||||
**Method**
|
|
||||||
|
|
||||||
One of two differentiation algorithms can be chosen with the
|
|
||||||
``method`` keyword argument. The two options are ``'step'``,
|
|
||||||
and ``'quad'``. The default method is ``'step'``.
|
|
||||||
|
|
||||||
``'step'``:
|
|
||||||
|
|
||||||
The derivative is computed using a finite difference
|
|
||||||
approximation, with a small step h. This requires n+1 function
|
|
||||||
evaluations and must be performed at (n+1) times the target
|
|
||||||
precision. Accordingly, f must support fast evaluation at high
|
|
||||||
precision.
|
|
||||||
|
|
||||||
``'quad'``:
|
|
||||||
|
|
||||||
The derivative is computed using complex
|
|
||||||
numerical integration. This requires a larger number of function
|
|
||||||
evaluations, but the advantage is that not much extra precision
|
|
||||||
is required. For high order derivatives, this method may thus
|
|
||||||
be faster if f is very expensive to evaluate at high precision.
|
|
||||||
|
|
||||||
With ``'quad'`` the result is likely to have a small imaginary
|
|
||||||
component even if the derivative is actually real::
|
|
||||||
|
|
||||||
>>> diff(sqrt, 1, method='quad') # doctest:+ELLIPSIS
|
|
||||||
(0.5 - 9.44...e-27j)
|
|
||||||
|
|
||||||
**Scale**
|
|
||||||
|
|
||||||
The scale option specifies the scale of variation of f. The step
|
|
||||||
size in the finite difference is taken to be approximately
|
|
||||||
eps*scale. Thus, for example if `f(x) = \cos(1000 x)`, the scale
|
|
||||||
should be set to 1/1000 and if `f(x) = \cos(x/1000)`, the scale
|
|
||||||
should be 1000. By default, scale = 1.
|
|
||||||
|
|
||||||
(In practice, the default scale will work even for `\cos(1000 x)` or
|
|
||||||
`\cos(x/1000)`. Changing this parameter is a good idea if the scale
|
|
||||||
is something *preposterous*.)
|
|
||||||
|
|
||||||
If numerical integration is used, the radius of integration is
|
|
||||||
taken to be equal to scale/2. Note that f must not have any
|
|
||||||
singularities within the circle of radius scale/2 centered around
|
|
||||||
x. If possible, a larger scale value is preferable because it
|
|
||||||
typically makes the integration faster and more accurate.
|
|
||||||
|
|
||||||
**Direction**
|
|
||||||
|
|
||||||
By default, :func:`diff` uses a central difference approximation.
|
|
||||||
This corresponds to direction=0. Alternatively, it can compute a
|
|
||||||
left difference (direction=-1) or right difference (direction=1).
|
|
||||||
This is useful for computing left- or right-sided derivatives
|
|
||||||
of nonsmooth functions:
|
|
||||||
|
|
||||||
>>> diff(abs, 0, direction=0)
|
|
||||||
0.0
|
|
||||||
>>> diff(abs, 0, direction=1)
|
|
||||||
1.0
|
|
||||||
>>> diff(abs, 0, direction=-1)
|
|
||||||
-1.0
|
|
||||||
|
|
||||||
More generally, if the direction is nonzero, a right difference
|
|
||||||
is computed where the step size is multiplied by sign(direction).
|
|
||||||
For example, with direction=+j, the derivative from the positive
|
|
||||||
imaginary direction will be computed.
|
|
||||||
|
|
||||||
This option only makes sense with method='step'. If integration
|
|
||||||
is used, it is assumed that f is analytic, implying that the
|
|
||||||
derivative is the same in all directions.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if n == 0:
|
|
||||||
return f(ctx.convert(x))
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
if method == 'step':
|
|
||||||
ctx.prec = (orig+20) * (n+1)
|
|
||||||
h = ctx.ldexp(scale, -orig-10)
|
|
||||||
# Applying the finite difference formula recursively n times,
|
|
||||||
# we get a step sum weighted by a row of binomial coefficients
|
|
||||||
# Directed: steps x, x+h, ... x+n*h
|
|
||||||
if direction:
|
|
||||||
h *= ctx.sign(direction)
|
|
||||||
steps = xrange(n+1)
|
|
||||||
norm = h**n
|
|
||||||
# Central: steps x-n*h, x-(n-2)*h ..., x, ..., x+(n-2)*h, x+n*h
|
|
||||||
else:
|
|
||||||
steps = xrange(-n, n+1, 2)
|
|
||||||
norm = (2*h)**n
|
|
||||||
v = ctx.difference_delta([f(x+k*h) for k in steps], n)
|
|
||||||
v = v / norm
|
|
||||||
elif method == 'quad':
|
|
||||||
ctx.prec += 10
|
|
||||||
radius = ctx.mpf(scale)/2
|
|
||||||
def g(t):
|
|
||||||
rei = radius*ctx.expj(t)
|
|
||||||
z = x + rei
|
|
||||||
return f(z) / rei**n
|
|
||||||
d = ctx.quadts(g, [0, 2*ctx.pi])
|
|
||||||
v = d * ctx.factorial(n) / (2*ctx.pi)
|
|
||||||
else:
|
|
||||||
raise ValueError("unknown method: %r" % method)
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
return +v
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def diffs(ctx, f, x, n=None, method='step', scale=1, direction=0):
|
|
||||||
r"""
|
|
||||||
Returns a generator that yields the sequence of derivatives
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
f(x), f'(x), f''(x), \ldots, f^{(k)}(x), \ldots
|
|
||||||
|
|
||||||
With ``method='step'``, :func:`diffs` uses only `O(k)`
|
|
||||||
function evaluations to generate the first `k` derivatives,
|
|
||||||
rather than the roughly `O(k^2)` evaluations
|
|
||||||
required if one calls :func:`diff` `k` separate times.
|
|
||||||
|
|
||||||
With `n < \infty`, the generator stops as soon as the
|
|
||||||
`n`-th derivative has been generated. If the exact number of
|
|
||||||
needed derivatives is known in advance, this is further
|
|
||||||
slightly more efficient.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> nprint(list(diffs(cos, 1, 5)))
|
|
||||||
[0.540302, -0.841471, -0.540302, 0.841471, 0.540302, -0.841471]
|
|
||||||
>>> for i, d in zip(range(6), diffs(cos, 1)): print i, d
|
|
||||||
...
|
|
||||||
0 0.54030230586814
|
|
||||||
1 -0.841470984807897
|
|
||||||
2 -0.54030230586814
|
|
||||||
3 0.841470984807897
|
|
||||||
4 0.54030230586814
|
|
||||||
5 -0.841470984807897
|
|
||||||
|
|
||||||
"""
|
|
||||||
if n is None:
|
|
||||||
n = ctx.inf
|
|
||||||
else:
|
|
||||||
n = int(n)
|
|
||||||
|
|
||||||
if method != 'step':
|
|
||||||
k = 0
|
|
||||||
while k < n:
|
|
||||||
yield ctx.diff(f, x, k)
|
|
||||||
k += 1
|
|
||||||
return
|
|
||||||
|
|
||||||
targetprec = ctx.prec
|
|
||||||
|
|
||||||
def getvalues(m):
|
|
||||||
callprec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec = workprec = (targetprec+20) * (m+1)
|
|
||||||
h = ctx.ldexp(scale, -targetprec-10)
|
|
||||||
if direction:
|
|
||||||
h *= ctx.sign(direction)
|
|
||||||
y = [f(x+h*k) for k in xrange(m+1)]
|
|
||||||
hnorm = h
|
|
||||||
else:
|
|
||||||
y = [f(x+h*k) for k in xrange(-m, m+1, 2)]
|
|
||||||
hnorm = 2*h
|
|
||||||
return y, hnorm, workprec
|
|
||||||
finally:
|
|
||||||
ctx.prec = callprec
|
|
||||||
|
|
||||||
yield f(ctx.convert(x))
|
|
||||||
if n < 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
if n == ctx.inf:
|
|
||||||
A, B = 1, 2
|
|
||||||
else:
|
|
||||||
A, B = 1, n+1
|
|
||||||
|
|
||||||
while 1:
|
|
||||||
y, hnorm, workprec = getvalues(B)
|
|
||||||
for k in xrange(A, B):
|
|
||||||
try:
|
|
||||||
callprec = ctx.prec
|
|
||||||
ctx.prec = workprec
|
|
||||||
d = ctx.difference_delta(y, k) / hnorm**k
|
|
||||||
finally:
|
|
||||||
ctx.prec = callprec
|
|
||||||
yield +d
|
|
||||||
if k >= n:
|
|
||||||
return
|
|
||||||
A, B = B, int(A*1.4+1)
|
|
||||||
B = min(B, n)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def differint(ctx, f, x, n=1, x0=0):
|
|
||||||
r"""
|
|
||||||
Calculates the Riemann-Liouville differintegral, or fractional
|
|
||||||
derivative, defined by
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\,_{x_0}{\mathbb{D}}^n_xf(x) \frac{1}{\Gamma(m-n)} \frac{d^m}{dx^m}
|
|
||||||
\int_{x_0}^{x}(x-t)^{m-n-1}f(t)dt
|
|
||||||
|
|
||||||
where `f` is a given (presumably well-behaved) function,
|
|
||||||
`x` is the evaluation point, `n` is the order, and `x_0` is
|
|
||||||
the reference point of integration (`m` is an arbitrary
|
|
||||||
parameter selected automatically).
|
|
||||||
|
|
||||||
With `n = 1`, this is just the standard derivative `f'(x)`; with `n = 2`,
|
|
||||||
the second derivative `f''(x)`, etc. With `n = -1`, it gives
|
|
||||||
`\int_{x_0}^x f(t) dt`, with `n = -2`
|
|
||||||
it gives `\int_{x_0}^x \left( \int_{x_0}^t f(u) du \right) dt`, etc.
|
|
||||||
|
|
||||||
As `n` is permitted to be any number, this operator generalizes
|
|
||||||
iterated differentiation and iterated integration to a single
|
|
||||||
operator with a continuous order parameter.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
There is an exact formula for the fractional derivative of a
|
|
||||||
monomial `x^p`, which may be used as a reference. For example,
|
|
||||||
the following gives a half-derivative (order 0.5)::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> x = mpf(3); p = 2; n = 0.5
|
|
||||||
>>> differint(lambda t: t**p, x, n)
|
|
||||||
7.81764019044672
|
|
||||||
>>> gamma(p+1)/gamma(p-n+1) * x**(p-n)
|
|
||||||
7.81764019044672
|
|
||||||
|
|
||||||
Another useful test function is the exponential function, whose
|
|
||||||
integration / differentiation formula easy generalizes
|
|
||||||
to arbitrary order. Here we first compute a third derivative,
|
|
||||||
and then a triply nested integral. (The reference point `x_0`
|
|
||||||
is set to `-\infty` to avoid nonzero endpoint terms.)::
|
|
||||||
|
|
||||||
>>> differint(lambda x: exp(pi*x), -1.5, 3)
|
|
||||||
0.278538406900792
|
|
||||||
>>> exp(pi*-1.5) * pi**3
|
|
||||||
0.278538406900792
|
|
||||||
>>> differint(lambda x: exp(pi*x), 3.5, -3, -inf)
|
|
||||||
1922.50563031149
|
|
||||||
>>> exp(pi*3.5) / pi**3
|
|
||||||
1922.50563031149
|
|
||||||
|
|
||||||
However, for noninteger `n`, the differentiation formula for the
|
|
||||||
exponential function must be modified to give the same result as the
|
|
||||||
Riemann-Liouville differintegral::
|
|
||||||
|
|
||||||
>>> x = mpf(3.5)
|
|
||||||
>>> c = pi
|
|
||||||
>>> n = 1+2*j
|
|
||||||
>>> differint(lambda x: exp(c*x), x, n)
|
|
||||||
(-123295.005390743 + 140955.117867654j)
|
|
||||||
>>> x**(-n) * exp(c)**x * (x*c)**n * gammainc(-n, 0, x*c) / gamma(-n)
|
|
||||||
(-123295.005390743 + 140955.117867654j)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
m = max(int(ctx.ceil(ctx.re(n)))+1, 1)
|
|
||||||
r = m-n-1
|
|
||||||
g = lambda x: ctx.quad(lambda t: (x-t)**r * f(t), [x0, x])
|
|
||||||
return ctx.diff(g, x, m) / ctx.gamma(m-n)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def diffun(ctx, f, n=1, **options):
|
|
||||||
"""
|
|
||||||
Given a function f, returns a function g(x) that evaluates the nth
|
|
||||||
derivative f^(n)(x)::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> cos2 = diffun(sin)
|
|
||||||
>>> sin2 = diffun(sin, 4)
|
|
||||||
>>> cos(1.3), cos2(1.3)
|
|
||||||
(0.267498828624587, 0.267498828624587)
|
|
||||||
>>> sin(1.3), sin2(1.3)
|
|
||||||
(0.963558185417193, 0.963558185417193)
|
|
||||||
|
|
||||||
The function f must support arbitrary precision evaluation.
|
|
||||||
See :func:`diff` for additional details and supported
|
|
||||||
keyword options.
|
|
||||||
"""
|
|
||||||
if n == 0:
|
|
||||||
return f
|
|
||||||
def g(x):
|
|
||||||
return ctx.diff(f, x, n, **options)
|
|
||||||
return g
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def taylor(ctx, f, x, n, **options):
|
|
||||||
r"""
|
|
||||||
Produces a degree-`n` Taylor polynomial around the point `x` of the
|
|
||||||
given function `f`. The coefficients are returned as a list.
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> nprint(chop(taylor(sin, 0, 5)))
|
|
||||||
[0.0, 1.0, 0.0, -0.166667, 0.0, 0.00833333]
|
|
||||||
|
|
||||||
The coefficients are computed using high-order numerical
|
|
||||||
differentiation. The function must be possible to evaluate
|
|
||||||
to arbitrary precision. See :func:`diff` for additional details
|
|
||||||
and supported keyword options.
|
|
||||||
|
|
||||||
Note that to evaluate the Taylor polynomial as an approximation
|
|
||||||
of `f`, e.g. with :func:`polyval`, the coefficients must be reversed,
|
|
||||||
and the point of the Taylor expansion must be subtracted from
|
|
||||||
the argument:
|
|
||||||
|
|
||||||
>>> p = taylor(exp, 2.0, 10)
|
|
||||||
>>> polyval(p[::-1], 2.5 - 2.0)
|
|
||||||
12.1824939606092
|
|
||||||
>>> exp(2.5)
|
|
||||||
12.1824939607035
|
|
||||||
|
|
||||||
"""
|
|
||||||
return [d/ctx.factorial(i) for i, d in enumerate(ctx.diffs(f, x, n, **options))]
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def pade(ctx, a, L, M):
|
|
||||||
r"""
|
|
||||||
Computes a Pade approximation of degree `(L, M)` to a function.
|
|
||||||
Given at least `L+M+1` Taylor coefficients `a` approximating
|
|
||||||
a function `A(x)`, :func:`pade` returns coefficients of
|
|
||||||
polynomials `P, Q` satisfying
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
P = \sum_{k=0}^L p_k x^k
|
|
||||||
|
|
||||||
Q = \sum_{k=0}^M q_k x^k
|
|
||||||
|
|
||||||
Q_0 = 1
|
|
||||||
|
|
||||||
A(x) Q(x) = P(x) + O(x^{L+M+1})
|
|
||||||
|
|
||||||
`P(x)/Q(x)` can provide a good approximation to an analytic function
|
|
||||||
beyond the radius of convergence of its Taylor series (example
|
|
||||||
from G.A. Baker 'Essentials of Pade Approximants' Academic Press,
|
|
||||||
Ch.1A)::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> one = mpf(1)
|
|
||||||
>>> def f(x):
|
|
||||||
... return sqrt((one + 2*x)/(one + x))
|
|
||||||
...
|
|
||||||
>>> a = taylor(f, 0, 6)
|
|
||||||
>>> p, q = pade(a, 3, 3)
|
|
||||||
>>> x = 10
|
|
||||||
>>> polyval(p[::-1], x)/polyval(q[::-1], x)
|
|
||||||
1.38169105566806
|
|
||||||
>>> f(x)
|
|
||||||
1.38169855941551
|
|
||||||
|
|
||||||
"""
|
|
||||||
# To determine L+1 coefficients of P and M coefficients of Q
|
|
||||||
# L+M+1 coefficients of A must be provided
|
|
||||||
assert(len(a) >= L+M+1)
|
|
||||||
|
|
||||||
if M == 0:
|
|
||||||
if L == 0:
|
|
||||||
return [ctx.one], [ctx.one]
|
|
||||||
else:
|
|
||||||
return a[:L+1], [ctx.one]
|
|
||||||
|
|
||||||
# Solve first
|
|
||||||
# a[L]*q[1] + ... + a[L-M+1]*q[M] = -a[L+1]
|
|
||||||
# ...
|
|
||||||
# a[L+M-1]*q[1] + ... + a[L]*q[M] = -a[L+M]
|
|
||||||
A = ctx.matrix(M)
|
|
||||||
for j in range(M):
|
|
||||||
for i in range(min(M, L+j+1)):
|
|
||||||
A[j, i] = a[L+j-i]
|
|
||||||
v = -ctx.matrix(a[(L+1):(L+M+1)])
|
|
||||||
x = ctx.lu_solve(A, v)
|
|
||||||
q = [ctx.one] + list(x)
|
|
||||||
# compute p
|
|
||||||
p = [0]*(L+1)
|
|
||||||
for i in range(L+1):
|
|
||||||
s = a[i]
|
|
||||||
for j in range(1, min(M,i) + 1):
|
|
||||||
s += q[j]*a[i-j]
|
|
||||||
p[i] = s
|
|
||||||
return p, q
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,287 +0,0 @@
|
||||||
from bisect import bisect
|
|
||||||
|
|
||||||
class ODEMethods(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def ode_taylor(ctx, derivs, x0, y0, tol_prec, n):
|
|
||||||
h = tol = ctx.ldexp(1, -tol_prec)
|
|
||||||
dim = len(y0)
|
|
||||||
xs = [x0]
|
|
||||||
ys = [y0]
|
|
||||||
x = x0
|
|
||||||
y = y0
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec = orig*(1+n)
|
|
||||||
# Use n steps with Euler's method to get
|
|
||||||
# evaluation points for derivatives
|
|
||||||
for i in range(n):
|
|
||||||
fxy = derivs(x, y)
|
|
||||||
y = [y[i]+h*fxy[i] for i in xrange(len(y))]
|
|
||||||
x += h
|
|
||||||
xs.append(x)
|
|
||||||
ys.append(y)
|
|
||||||
# Compute derivatives
|
|
||||||
ser = [[] for d in range(dim)]
|
|
||||||
for j in range(n+1):
|
|
||||||
s = [0]*dim
|
|
||||||
b = (-1) ** (j & 1)
|
|
||||||
k = 1
|
|
||||||
for i in range(j+1):
|
|
||||||
for d in range(dim):
|
|
||||||
s[d] += b * ys[i][d]
|
|
||||||
b = (b * (j-k+1)) // (-k)
|
|
||||||
k += 1
|
|
||||||
scale = h**(-j) / ctx.fac(j)
|
|
||||||
for d in range(dim):
|
|
||||||
s[d] = s[d] * scale
|
|
||||||
ser[d].append(s[d])
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
# Estimate radius for which we can get full accuracy.
|
|
||||||
# XXX: do this right for zeros
|
|
||||||
radius = ctx.one
|
|
||||||
for ts in ser:
|
|
||||||
if ts[-1]:
|
|
||||||
radius = min(radius, ctx.nthroot(tol/abs(ts[-1]), n))
|
|
||||||
radius /= 2 # XXX
|
|
||||||
return ser, x0+radius
|
|
||||||
|
|
||||||
def odefun(ctx, F, x0, y0, tol=None, degree=None, method='taylor', verbose=False):
|
|
||||||
r"""
|
|
||||||
Returns a function `y(x) = [y_0(x), y_1(x), \ldots, y_n(x)]`
|
|
||||||
that is a numerical solution of the `n+1`-dimensional first-order
|
|
||||||
ordinary differential equation (ODE) system
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
y_0'(x) = F_0(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
|
||||||
|
|
||||||
y_1'(x) = F_1(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
|
||||||
|
|
||||||
\vdots
|
|
||||||
|
|
||||||
y_n'(x) = F_n(x, [y_0(x), y_1(x), \ldots, y_n(x)])
|
|
||||||
|
|
||||||
The derivatives are specified by the vector-valued function
|
|
||||||
*F* that evaluates
|
|
||||||
`[y_0', \ldots, y_n'] = F(x, [y_0, \ldots, y_n])`.
|
|
||||||
The initial point `x_0` is specified by the scalar argument *x0*,
|
|
||||||
and the initial value `y(x_0) = [y_0(x_0), \ldots, y_n(x_0)]` is
|
|
||||||
specified by the vector argument *y0*.
|
|
||||||
|
|
||||||
For convenience, if the system is one-dimensional, you may optionally
|
|
||||||
provide just a scalar value for *y0*. In this case, *F* should accept
|
|
||||||
a scalar *y* argument and return a scalar. The solution function
|
|
||||||
*y* will return scalar values instead of length-1 vectors.
|
|
||||||
|
|
||||||
Evaluation of the solution function `y(x)` is permitted
|
|
||||||
for any `x \ge x_0`.
|
|
||||||
|
|
||||||
A high-order ODE can be solved by transforming it into first-order
|
|
||||||
vector form. This transformation is described in standard texts
|
|
||||||
on ODEs. Examples will also be given below.
|
|
||||||
|
|
||||||
**Options, speed and accuracy**
|
|
||||||
|
|
||||||
By default, :func:`odefun` uses a high-order Taylor series
|
|
||||||
method. For reasonably well-behaved problems, the solution will
|
|
||||||
be fully accurate to within the working precision. Note that
|
|
||||||
*F* must be possible to evaluate to very high precision
|
|
||||||
for the generation of Taylor series to work.
|
|
||||||
|
|
||||||
To get a faster but less accurate solution, you can set a large
|
|
||||||
value for *tol* (which defaults roughly to *eps*). If you just
|
|
||||||
want to plot the solution or perform a basic simulation,
|
|
||||||
*tol = 0.01* is likely sufficient.
|
|
||||||
|
|
||||||
The *degree* argument controls the degree of the solver (with
|
|
||||||
*method='taylor'*, this is the degree of the Taylor series
|
|
||||||
expansion). A higher degree means that a longer step can be taken
|
|
||||||
before a new local solution must be generated from *F*,
|
|
||||||
meaning that fewer steps are required to get from `x_0` to a given
|
|
||||||
`x_1`. On the other hand, a higher degree also means that each
|
|
||||||
local solution becomes more expensive (i.e., more evaluations of
|
|
||||||
*F* are required per step, and at higher precision).
|
|
||||||
|
|
||||||
The optimal setting therefore involves a tradeoff. Generally,
|
|
||||||
decreasing the *degree* for Taylor series is likely to give faster
|
|
||||||
solution at low precision, while increasing is likely to be better
|
|
||||||
at higher precision.
|
|
||||||
|
|
||||||
The function
|
|
||||||
object returned by :func:`odefun` caches the solutions at all step
|
|
||||||
points and uses polynomial interpolation between step points.
|
|
||||||
Therefore, once `y(x_1)` has been evaluated for some `x_1`,
|
|
||||||
`y(x)` can be evaluated very quickly for any `x_0 \le x \le x_1`.
|
|
||||||
and continuing the evaluation up to `x_2 > x_1` is also fast.
|
|
||||||
|
|
||||||
**Examples of first-order ODEs**
|
|
||||||
|
|
||||||
We will solve the standard test problem `y'(x) = y(x), y(0) = 1`
|
|
||||||
which has explicit solution `y(x) = \exp(x)`::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> f = odefun(lambda x, y: y, 0, 1)
|
|
||||||
>>> for x in [0, 1, 2.5]:
|
|
||||||
... print f(x), exp(x)
|
|
||||||
...
|
|
||||||
1.0 1.0
|
|
||||||
2.71828182845905 2.71828182845905
|
|
||||||
12.1824939607035 12.1824939607035
|
|
||||||
|
|
||||||
The solution with high precision::
|
|
||||||
|
|
||||||
>>> mp.dps = 50
|
|
||||||
>>> f = odefun(lambda x, y: y, 0, 1)
|
|
||||||
>>> f(1)
|
|
||||||
2.7182818284590452353602874713526624977572470937
|
|
||||||
>>> exp(1)
|
|
||||||
2.7182818284590452353602874713526624977572470937
|
|
||||||
|
|
||||||
Using the more general vectorized form, the test problem
|
|
||||||
can be input as (note that *f* returns a 1-element vector)::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> f = odefun(lambda x, y: [y[0]], 0, [1])
|
|
||||||
>>> f(1)
|
|
||||||
[2.71828182845905]
|
|
||||||
|
|
||||||
:func:`odefun` can solve nonlinear ODEs, which are generally
|
|
||||||
impossible (and at best difficult) to solve analytically. As
|
|
||||||
an example of a nonlinear ODE, we will solve `y'(x) = x \sin(y(x))`
|
|
||||||
for `y(0) = \pi/2`. An exact solution happens to be known
|
|
||||||
for this problem, and is given by
|
|
||||||
`y(x) = 2 \tan^{-1}\left(\exp\left(x^2/2\right)\right)`::
|
|
||||||
|
|
||||||
>>> f = odefun(lambda x, y: x*sin(y), 0, pi/2)
|
|
||||||
>>> for x in [2, 5, 10]:
|
|
||||||
... print f(x), 2*atan(exp(mpf(x)**2/2))
|
|
||||||
...
|
|
||||||
2.87255666284091 2.87255666284091
|
|
||||||
3.14158520028345 3.14158520028345
|
|
||||||
3.14159265358979 3.14159265358979
|
|
||||||
|
|
||||||
If `F` is independent of `y`, an ODE can be solved using direct
|
|
||||||
integration. We can therefore obtain a reference solution with
|
|
||||||
:func:`quad`::
|
|
||||||
|
|
||||||
>>> f = lambda x: (1+x**2)/(1+x**3)
|
|
||||||
>>> g = odefun(lambda x, y: f(x), pi, 0)
|
|
||||||
>>> g(2*pi)
|
|
||||||
0.72128263801696
|
|
||||||
>>> quad(f, [pi, 2*pi])
|
|
||||||
0.72128263801696
|
|
||||||
|
|
||||||
**Examples of second-order ODEs**
|
|
||||||
|
|
||||||
We will solve the harmonic oscillator equation `y''(x) + y(x) = 0`.
|
|
||||||
To do this, we introduce the helper functions `y_0 = y, y_1 = y_0'`
|
|
||||||
whereby the original equation can be written as `y_1' + y_0' = 0`. Put
|
|
||||||
together, we get the first-order, two-dimensional vector ODE
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\begin{cases}
|
|
||||||
y_0' = y_1 \\
|
|
||||||
y_1' = -y_0
|
|
||||||
\end{cases}
|
|
||||||
|
|
||||||
To get a well-defined IVP, we need two initial values. With
|
|
||||||
`y(0) = y_0(0) = 1` and `-y'(0) = y_1(0) = 0`, the problem will of
|
|
||||||
course be solved by `y(x) = y_0(x) = \cos(x)` and
|
|
||||||
`-y'(x) = y_1(x) = \sin(x)`. We check this::
|
|
||||||
|
|
||||||
>>> f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
|
|
||||||
>>> for x in [0, 1, 2.5, 10]:
|
|
||||||
... nprint(f(x), 15)
|
|
||||||
... nprint([cos(x), sin(x)], 15)
|
|
||||||
... print "---"
|
|
||||||
...
|
|
||||||
[1.0, 0.0]
|
|
||||||
[1.0, 0.0]
|
|
||||||
---
|
|
||||||
[0.54030230586814, 0.841470984807897]
|
|
||||||
[0.54030230586814, 0.841470984807897]
|
|
||||||
---
|
|
||||||
[-0.801143615546934, 0.598472144103957]
|
|
||||||
[-0.801143615546934, 0.598472144103957]
|
|
||||||
---
|
|
||||||
[-0.839071529076452, -0.54402111088937]
|
|
||||||
[-0.839071529076452, -0.54402111088937]
|
|
||||||
---
|
|
||||||
|
|
||||||
Note that we get both the sine and the cosine solutions
|
|
||||||
simultaneously.
|
|
||||||
|
|
||||||
**TODO**
|
|
||||||
|
|
||||||
* Better automatic choice of degree and step size
|
|
||||||
* Make determination of Taylor series convergence radius
|
|
||||||
more robust
|
|
||||||
* Allow solution for `x < x_0`
|
|
||||||
* Allow solution for complex `x`
|
|
||||||
* Test for difficult (ill-conditioned) problems
|
|
||||||
* Implement Runge-Kutta and other algorithms
|
|
||||||
|
|
||||||
"""
|
|
||||||
if tol:
|
|
||||||
tol_prec = int(-ctx.log(tol, 2))+10
|
|
||||||
else:
|
|
||||||
tol_prec = ctx.prec+10
|
|
||||||
degree = degree or (3 + int(3*ctx.dps/2.))
|
|
||||||
workprec = ctx.prec + 40
|
|
||||||
try:
|
|
||||||
len(y0)
|
|
||||||
return_vector = True
|
|
||||||
except TypeError:
|
|
||||||
F_ = F
|
|
||||||
F = lambda x, y: [F_(x, y[0])]
|
|
||||||
y0 = [y0]
|
|
||||||
return_vector = False
|
|
||||||
ser, xb = ode_taylor(ctx, F, x0, y0, tol_prec, degree)
|
|
||||||
series_boundaries = [x0, xb]
|
|
||||||
series_data = [(ser, x0, xb)]
|
|
||||||
# We will be working with vectors of Taylor series
|
|
||||||
def mpolyval(ser, a):
|
|
||||||
return [ctx.polyval(s[::-1], a) for s in ser]
|
|
||||||
# Find nearest expansion point; compute if necessary
|
|
||||||
def get_series(x):
|
|
||||||
if x < x0:
|
|
||||||
raise ValueError
|
|
||||||
n = bisect(series_boundaries, x)
|
|
||||||
if n < len(series_boundaries):
|
|
||||||
return series_data[n-1]
|
|
||||||
while 1:
|
|
||||||
ser, xa, xb = series_data[-1]
|
|
||||||
if verbose:
|
|
||||||
print "Computing Taylor series for [%f, %f]" % (xa, xb)
|
|
||||||
y = mpolyval(ser, xb-xa)
|
|
||||||
xa = xb
|
|
||||||
ser, xb = ode_taylor(ctx, F, xb, y, tol_prec, degree)
|
|
||||||
series_boundaries.append(xb)
|
|
||||||
series_data.append((ser, xa, xb))
|
|
||||||
if x <= xb:
|
|
||||||
return series_data[-1]
|
|
||||||
# Evaluation function
|
|
||||||
def interpolant(x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec = workprec
|
|
||||||
ser, xa, xb = get_series(x)
|
|
||||||
y = mpolyval(ser, x-xa)
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
if return_vector:
|
|
||||||
return [+yk for yk in y]
|
|
||||||
else:
|
|
||||||
return +y[0]
|
|
||||||
return interpolant
|
|
||||||
|
|
||||||
ODEMethods.odefun = odefun
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,189 +0,0 @@
|
||||||
from calculus import defun
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
# Polynomials #
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
|
|
||||||
# XXX: extra precision
|
|
||||||
@defun
|
|
||||||
def polyval(ctx, coeffs, x, derivative=False):
|
|
||||||
r"""
|
|
||||||
Given coefficients `[c_n, \ldots, c_2, c_1, c_0]` and a number `x`,
|
|
||||||
:func:`polyval` evaluates the polynomial
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
P(x) = c_n x^n + \ldots + c_2 x^2 + c_1 x + c_0.
|
|
||||||
|
|
||||||
If *derivative=True* is set, :func:`polyval` simultaneously
|
|
||||||
evaluates `P(x)` with the derivative, `P'(x)`, and returns the
|
|
||||||
tuple `(P(x), P'(x))`.
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.pretty = True
|
|
||||||
>>> polyval([3, 0, 2], 0.5)
|
|
||||||
2.75
|
|
||||||
>>> polyval([3, 0, 2], 0.5, derivative=True)
|
|
||||||
(2.75, 3.0)
|
|
||||||
|
|
||||||
The coefficients and the evaluation point may be any combination
|
|
||||||
of real or complex numbers.
|
|
||||||
"""
|
|
||||||
if not coeffs:
|
|
||||||
return ctx.zero
|
|
||||||
p = ctx.convert(coeffs[0])
|
|
||||||
q = ctx.zero
|
|
||||||
for c in coeffs[1:]:
|
|
||||||
if derivative:
|
|
||||||
q = p + x*q
|
|
||||||
p = c + x*p
|
|
||||||
if derivative:
|
|
||||||
return p, q
|
|
||||||
else:
|
|
||||||
return p
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def polyroots(ctx, coeffs, maxsteps=50, cleanup=True, extraprec=10, error=False):
|
|
||||||
"""
|
|
||||||
Computes all roots (real or complex) of a given polynomial. The roots are
|
|
||||||
returned as a sorted list, where real roots appear first followed by
|
|
||||||
complex conjugate roots as adjacent elements. The polynomial should be
|
|
||||||
given as a list of coefficients, in the format used by :func:`polyval`.
|
|
||||||
The leading coefficient must be nonzero.
|
|
||||||
|
|
||||||
With *error=True*, :func:`polyroots` returns a tuple *(roots, err)* where
|
|
||||||
*err* is an estimate of the maximum error among the computed roots.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Finding the three real roots of `x^3 - x^2 - 14x + 24`::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> nprint(polyroots([1,-1,-14,24]), 4)
|
|
||||||
[-4.0, 2.0, 3.0]
|
|
||||||
|
|
||||||
Finding the two complex conjugate roots of `4x^2 + 3x + 2`, with an
|
|
||||||
error estimate::
|
|
||||||
|
|
||||||
>>> roots, err = polyroots([4,3,2], error=True)
|
|
||||||
>>> for r in roots:
|
|
||||||
... print r
|
|
||||||
...
|
|
||||||
(-0.375 + 0.59947894041409j)
|
|
||||||
(-0.375 - 0.59947894041409j)
|
|
||||||
>>>
|
|
||||||
>>> err
|
|
||||||
2.22044604925031e-16
|
|
||||||
>>>
|
|
||||||
>>> polyval([4,3,2], roots[0])
|
|
||||||
(2.22044604925031e-16 + 0.0j)
|
|
||||||
>>> polyval([4,3,2], roots[1])
|
|
||||||
(2.22044604925031e-16 + 0.0j)
|
|
||||||
|
|
||||||
The following example computes all the 5th roots of unity; that is,
|
|
||||||
the roots of `x^5 - 1`::
|
|
||||||
|
|
||||||
>>> mp.dps = 20
|
|
||||||
>>> for r in polyroots([1, 0, 0, 0, 0, -1]):
|
|
||||||
... print r
|
|
||||||
...
|
|
||||||
1.0
|
|
||||||
(-0.8090169943749474241 + 0.58778525229247312917j)
|
|
||||||
(-0.8090169943749474241 - 0.58778525229247312917j)
|
|
||||||
(0.3090169943749474241 + 0.95105651629515357212j)
|
|
||||||
(0.3090169943749474241 - 0.95105651629515357212j)
|
|
||||||
|
|
||||||
**Precision and conditioning**
|
|
||||||
|
|
||||||
Provided there are no repeated roots, :func:`polyroots` can typically
|
|
||||||
compute all roots of an arbitrary polynomial to high precision::
|
|
||||||
|
|
||||||
>>> mp.dps = 60
|
|
||||||
>>> for r in polyroots([1, 0, -10, 0, 1]):
|
|
||||||
... print r
|
|
||||||
...
|
|
||||||
-3.14626436994197234232913506571557044551247712918732870123249
|
|
||||||
-0.317837245195782244725757617296174288373133378433432554879127
|
|
||||||
0.317837245195782244725757617296174288373133378433432554879127
|
|
||||||
3.14626436994197234232913506571557044551247712918732870123249
|
|
||||||
>>>
|
|
||||||
>>> sqrt(3) + sqrt(2)
|
|
||||||
3.14626436994197234232913506571557044551247712918732870123249
|
|
||||||
>>> sqrt(3) - sqrt(2)
|
|
||||||
0.317837245195782244725757617296174288373133378433432554879127
|
|
||||||
|
|
||||||
**Algorithm**
|
|
||||||
|
|
||||||
:func:`polyroots` implements the Durand-Kerner method [1], which
|
|
||||||
uses complex arithmetic to locate all roots simultaneously.
|
|
||||||
The Durand-Kerner method can be viewed as approximately performing
|
|
||||||
simultaneous Newton iteration for all the roots. In particular,
|
|
||||||
the convergence to simple roots is quadratic, just like Newton's
|
|
||||||
method.
|
|
||||||
|
|
||||||
Although all roots are internally calculated using complex arithmetic,
|
|
||||||
any root found to have an imaginary part smaller than the estimated
|
|
||||||
numerical error is truncated to a real number. Real roots are placed
|
|
||||||
first in the returned list, sorted by value. The remaining complex
|
|
||||||
roots are sorted by real their parts so that conjugate roots end up
|
|
||||||
next to each other.
|
|
||||||
|
|
||||||
**References**
|
|
||||||
|
|
||||||
1. http://en.wikipedia.org/wiki/Durand-Kerner_method
|
|
||||||
|
|
||||||
"""
|
|
||||||
if len(coeffs) <= 1:
|
|
||||||
if not coeffs or not coeffs[0]:
|
|
||||||
raise ValueError("Input to polyroots must not be the zero polynomial")
|
|
||||||
# Constant polynomial with no roots
|
|
||||||
return []
|
|
||||||
|
|
||||||
orig = ctx.prec
|
|
||||||
weps = +ctx.eps
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
tol = ctx.eps * 128
|
|
||||||
deg = len(coeffs) - 1
|
|
||||||
# Must be monic
|
|
||||||
lead = ctx.convert(coeffs[0])
|
|
||||||
if lead == 1:
|
|
||||||
coeffs = map(ctx.convert, coeffs)
|
|
||||||
else:
|
|
||||||
coeffs = [c/lead for c in coeffs]
|
|
||||||
f = lambda x: ctx.polyval(coeffs, x)
|
|
||||||
roots = [ctx.mpc((0.4+0.9j)**n) for n in xrange(deg)]
|
|
||||||
err = [ctx.one for n in xrange(deg)]
|
|
||||||
# Durand-Kerner iteration until convergence
|
|
||||||
for step in xrange(maxsteps):
|
|
||||||
if abs(max(err)) < tol:
|
|
||||||
break
|
|
||||||
for i in xrange(deg):
|
|
||||||
if not abs(err[i]) < tol:
|
|
||||||
p = roots[i]
|
|
||||||
x = f(p)
|
|
||||||
for j in range(deg):
|
|
||||||
if i != j:
|
|
||||||
try:
|
|
||||||
x /= (p-roots[j])
|
|
||||||
except ZeroDivisionError:
|
|
||||||
continue
|
|
||||||
roots[i] = p - x
|
|
||||||
err[i] = abs(x)
|
|
||||||
# Remove small imaginary parts
|
|
||||||
if cleanup:
|
|
||||||
for i in xrange(deg):
|
|
||||||
if abs(ctx._im(roots[i])) < weps:
|
|
||||||
roots[i] = roots[i].real
|
|
||||||
elif abs(ctx._re(roots[i])) < weps:
|
|
||||||
roots[i] = roots[i].imag * 1j
|
|
||||||
roots.sort(key=lambda x: (abs(ctx._im(x)), ctx._re(x)))
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
if error:
|
|
||||||
err = max(err)
|
|
||||||
err = max(err, ctx.ldexp(1, -orig+1))
|
|
||||||
return [+r for r in roots], +err
|
|
||||||
else:
|
|
||||||
return [+r for r in roots]
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,8 +0,0 @@
|
||||||
# The py library is part of the "py.test" testing suite (python-codespeak-lib
|
|
||||||
# on Debian), see http://codespeak.net/py/
|
|
||||||
|
|
||||||
import py
|
|
||||||
|
|
||||||
#this makes py.test put mpath directory into the sys.path, so that we can
|
|
||||||
#"import mpmath" from tests nicely
|
|
||||||
rootdir = py.magic.autopath().dirpath()
|
|
||||||
|
|
@ -1,324 +0,0 @@
|
||||||
from operator import gt, lt
|
|
||||||
|
|
||||||
from functions.functions import SpecialFunctions
|
|
||||||
from functions.rszeta import RSCache
|
|
||||||
from calculus.quadrature import QuadratureMethods
|
|
||||||
from calculus.calculus import CalculusMethods
|
|
||||||
from calculus.optimization import OptimizationMethods
|
|
||||||
from calculus.odes import ODEMethods
|
|
||||||
from matrices.matrices import MatrixMethods
|
|
||||||
from matrices.calculus import MatrixCalculusMethods
|
|
||||||
from matrices.linalg import LinearAlgebraMethods
|
|
||||||
from identification import IdentificationMethods
|
|
||||||
from visualization import VisualizationMethods
|
|
||||||
|
|
||||||
import libmp
|
|
||||||
|
|
||||||
class Context(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class StandardBaseContext(Context,
|
|
||||||
SpecialFunctions,
|
|
||||||
RSCache,
|
|
||||||
QuadratureMethods,
|
|
||||||
CalculusMethods,
|
|
||||||
MatrixMethods,
|
|
||||||
MatrixCalculusMethods,
|
|
||||||
LinearAlgebraMethods,
|
|
||||||
IdentificationMethods,
|
|
||||||
OptimizationMethods,
|
|
||||||
ODEMethods,
|
|
||||||
VisualizationMethods):
|
|
||||||
|
|
||||||
NoConvergence = libmp.NoConvergence
|
|
||||||
ComplexResult = libmp.ComplexResult
|
|
||||||
|
|
||||||
def __init__(ctx):
|
|
||||||
ctx._aliases = {}
|
|
||||||
# Call those that need preinitialization (e.g. for wrappers)
|
|
||||||
SpecialFunctions.__init__(ctx)
|
|
||||||
RSCache.__init__(ctx)
|
|
||||||
QuadratureMethods.__init__(ctx)
|
|
||||||
CalculusMethods.__init__(ctx)
|
|
||||||
MatrixMethods.__init__(ctx)
|
|
||||||
|
|
||||||
def _init_aliases(ctx):
|
|
||||||
for alias, value in ctx._aliases.items():
|
|
||||||
try:
|
|
||||||
setattr(ctx, alias, getattr(ctx, value))
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
_fixed_precision = False
|
|
||||||
|
|
||||||
# XXX
|
|
||||||
verbose = False
|
|
||||||
|
|
||||||
def warn(ctx, msg):
|
|
||||||
print "Warning:", msg
|
|
||||||
|
|
||||||
def bad_domain(ctx, msg):
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
def _re(ctx, x):
|
|
||||||
if hasattr(x, "real"):
|
|
||||||
return x.real
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _im(ctx, x):
|
|
||||||
if hasattr(x, "imag"):
|
|
||||||
return x.imag
|
|
||||||
return ctx.zero
|
|
||||||
|
|
||||||
def chop(ctx, x, tol=None):
|
|
||||||
"""
|
|
||||||
Chops off small real or imaginary parts, or converts
|
|
||||||
numbers close to zero to exact zeros. The input can be a
|
|
||||||
single number or an iterable::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> chop(5+1e-10j, tol=1e-9)
|
|
||||||
mpf('5.0')
|
|
||||||
>>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
|
|
||||||
[1.0, 0.0, 3.0, -4.0, 2.0]
|
|
||||||
|
|
||||||
The tolerance defaults to ``100*eps``.
|
|
||||||
"""
|
|
||||||
if tol is None:
|
|
||||||
tol = 100*ctx.eps
|
|
||||||
try:
|
|
||||||
x = ctx.convert(x)
|
|
||||||
absx = abs(x)
|
|
||||||
if abs(x) < tol:
|
|
||||||
return ctx.zero
|
|
||||||
if ctx._is_complex_type(x):
|
|
||||||
if abs(x.imag) < min(tol, absx*tol):
|
|
||||||
return x.real
|
|
||||||
if abs(x.real) < min(tol, absx*tol):
|
|
||||||
return ctx.mpc(0, x.imag)
|
|
||||||
except TypeError:
|
|
||||||
if isinstance(x, ctx.matrix):
|
|
||||||
return x.apply(lambda a: ctx.chop(a, tol))
|
|
||||||
if hasattr(x, "__iter__"):
|
|
||||||
return [ctx.chop(a, tol) for a in x]
|
|
||||||
return x
|
|
||||||
|
|
||||||
def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
|
|
||||||
r"""
|
|
||||||
Determine whether the difference between `s` and `t` is smaller
|
|
||||||
than a given epsilon, either relatively or absolutely.
|
|
||||||
|
|
||||||
Both a maximum relative difference and a maximum difference
|
|
||||||
('epsilons') may be specified. The absolute difference is
|
|
||||||
defined as `|s-t|` and the relative difference is defined
|
|
||||||
as `|s-t|/\max(|s|, |t|)`.
|
|
||||||
|
|
||||||
If only one epsilon is given, both are set to the same value.
|
|
||||||
If none is given, both epsilons are set to `2^{-p+m}` where
|
|
||||||
`p` is the current working precision and `m` is a small
|
|
||||||
integer. The default setting typically allows :func:`almosteq`
|
|
||||||
to be used to check for mathematical equality
|
|
||||||
in the presence of small rounding errors.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> almosteq(3.141592653589793, 3.141592653589790)
|
|
||||||
True
|
|
||||||
>>> almosteq(3.141592653589793, 3.141592653589700)
|
|
||||||
False
|
|
||||||
>>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
|
|
||||||
True
|
|
||||||
>>> almosteq(1e-20, 2e-20)
|
|
||||||
True
|
|
||||||
>>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
|
|
||||||
False
|
|
||||||
|
|
||||||
"""
|
|
||||||
t = ctx.convert(t)
|
|
||||||
if abs_eps is None and rel_eps is None:
|
|
||||||
rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
|
|
||||||
if abs_eps is None:
|
|
||||||
abs_eps = rel_eps
|
|
||||||
elif rel_eps is None:
|
|
||||||
rel_eps = abs_eps
|
|
||||||
diff = abs(s-t)
|
|
||||||
if diff <= abs_eps:
|
|
||||||
return True
|
|
||||||
abss = abs(s)
|
|
||||||
abst = abs(t)
|
|
||||||
if abss < abst:
|
|
||||||
err = diff/abst
|
|
||||||
else:
|
|
||||||
err = diff/abss
|
|
||||||
return err <= rel_eps
|
|
||||||
|
|
||||||
def arange(ctx, *args):
|
|
||||||
r"""
|
|
||||||
This is a generalized version of Python's :func:`range` function
|
|
||||||
that accepts fractional endpoints and step sizes and
|
|
||||||
returns a list of ``mpf`` instances. Like :func:`range`,
|
|
||||||
:func:`arange` can be called with 1, 2 or 3 arguments:
|
|
||||||
|
|
||||||
``arange(b)``
|
|
||||||
`[0, 1, 2, \ldots, x]`
|
|
||||||
``arange(a, b)``
|
|
||||||
`[a, a+1, a+2, \ldots, x]`
|
|
||||||
``arange(a, b, h)``
|
|
||||||
`[a, a+h, a+h, \ldots, x]`
|
|
||||||
|
|
||||||
where `b-1 \le x < b` (in the third case, `b-h \le x < b`).
|
|
||||||
|
|
||||||
Like Python's :func:`range`, the endpoint is not included. To
|
|
||||||
produce ranges where the endpoint is included, :func:`linspace`
|
|
||||||
is more convenient.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> arange(4)
|
|
||||||
[mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
|
|
||||||
>>> arange(1, 2, 0.25)
|
|
||||||
[mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
|
|
||||||
>>> arange(1, -1, -0.75)
|
|
||||||
[mpf('1.0'), mpf('0.25'), mpf('-0.5')]
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not len(args) <= 3:
|
|
||||||
raise TypeError('arange expected at most 3 arguments, got %i'
|
|
||||||
% len(args))
|
|
||||||
if not len(args) >= 1:
|
|
||||||
raise TypeError('arange expected at least 1 argument, got %i'
|
|
||||||
% len(args))
|
|
||||||
# set default
|
|
||||||
a = 0
|
|
||||||
dt = 1
|
|
||||||
# interpret arguments
|
|
||||||
if len(args) == 1:
|
|
||||||
b = args[0]
|
|
||||||
elif len(args) >= 2:
|
|
||||||
a = args[0]
|
|
||||||
b = args[1]
|
|
||||||
if len(args) == 3:
|
|
||||||
dt = args[2]
|
|
||||||
a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
|
|
||||||
assert a + dt != a, 'dt is too small and would cause an infinite loop'
|
|
||||||
# adapt code for sign of dt
|
|
||||||
if a > b:
|
|
||||||
if dt > 0:
|
|
||||||
return []
|
|
||||||
op = gt
|
|
||||||
else:
|
|
||||||
if dt < 0:
|
|
||||||
return []
|
|
||||||
op = lt
|
|
||||||
# create list
|
|
||||||
result = []
|
|
||||||
i = 0
|
|
||||||
t = a
|
|
||||||
while 1:
|
|
||||||
t = a + dt*i
|
|
||||||
i += 1
|
|
||||||
if op(t, b):
|
|
||||||
result.append(t)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return result
|
|
||||||
|
|
||||||
def linspace(ctx, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
``linspace(a, b, n)`` returns a list of `n` evenly spaced
|
|
||||||
samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
|
|
||||||
is also valid.
|
|
||||||
|
|
||||||
This function is often more convenient than :func:`arange`
|
|
||||||
for partitioning an interval into subintervals, since
|
|
||||||
the endpoint is included::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> linspace(1, 4, 4)
|
|
||||||
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
|
||||||
>>> linspace(mpi(1,4), 4)
|
|
||||||
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
|
||||||
|
|
||||||
You may also provide the keyword argument ``endpoint=False``::
|
|
||||||
|
|
||||||
>>> linspace(1, 4, 4, endpoint=False)
|
|
||||||
[mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]
|
|
||||||
|
|
||||||
"""
|
|
||||||
if len(args) == 3:
|
|
||||||
a = ctx.mpf(args[0])
|
|
||||||
b = ctx.mpf(args[1])
|
|
||||||
n = int(args[2])
|
|
||||||
elif len(args) == 2:
|
|
||||||
assert hasattr(args[0], '_mpi_')
|
|
||||||
a = args[0].a
|
|
||||||
b = args[0].b
|
|
||||||
n = int(args[1])
|
|
||||||
else:
|
|
||||||
raise TypeError('linspace expected 2 or 3 arguments, got %i' \
|
|
||||||
% len(args))
|
|
||||||
if n < 1:
|
|
||||||
raise ValueError('n must be greater than 0')
|
|
||||||
if not 'endpoint' in kwargs or kwargs['endpoint']:
|
|
||||||
if n == 1:
|
|
||||||
return [ctx.mpf(a)]
|
|
||||||
step = (b - a) / ctx.mpf(n - 1)
|
|
||||||
y = [i*step + a for i in xrange(n)]
|
|
||||||
y[-1] = b
|
|
||||||
else:
|
|
||||||
step = (b - a) / ctx.mpf(n)
|
|
||||||
y = [i*step + a for i in xrange(n)]
|
|
||||||
return y
|
|
||||||
|
|
||||||
def cos_sin(ctx, z, **kwargs):
|
|
||||||
return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)
|
|
||||||
|
|
||||||
def _default_hyper_maxprec(ctx, p):
|
|
||||||
return int(1000 * p**0.25 + 4*p)
|
|
||||||
|
|
||||||
_gcd = staticmethod(libmp.gcd)
|
|
||||||
list_primes = staticmethod(libmp.list_primes)
|
|
||||||
bernfrac = staticmethod(libmp.bernfrac)
|
|
||||||
moebius = staticmethod(libmp.moebius)
|
|
||||||
_ifac = staticmethod(libmp.ifac)
|
|
||||||
_eulernum = staticmethod(libmp.eulernum)
|
|
||||||
|
|
||||||
def sum_accurately(ctx, terms, check_step=1):
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
extraprec = 10
|
|
||||||
while 1:
|
|
||||||
ctx.prec = prec + extraprec + 5
|
|
||||||
max_mag = ctx.ninf
|
|
||||||
s = ctx.zero
|
|
||||||
k = 0
|
|
||||||
for term in terms():
|
|
||||||
s += term
|
|
||||||
if (not k % check_step) and term:
|
|
||||||
term_mag = ctx.mag(term)
|
|
||||||
max_mag = max(max_mag, term_mag)
|
|
||||||
sum_mag = ctx.mag(s)
|
|
||||||
if sum_mag - term_mag > ctx.prec:
|
|
||||||
break
|
|
||||||
k += 1
|
|
||||||
cancellation = max_mag - sum_mag
|
|
||||||
if cancellation != cancellation:
|
|
||||||
break
|
|
||||||
if cancellation < extraprec or ctx._fixed_precision:
|
|
||||||
break
|
|
||||||
extraprec += min(ctx.prec, cancellation)
|
|
||||||
return s
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
|
|
||||||
def power(ctx, x, y):
|
|
||||||
return ctx.convert(x) ** ctx.convert(y)
|
|
||||||
|
|
||||||
def _zeta_int(ctx, n):
|
|
||||||
return ctx.zeta(n)
|
|
||||||
|
|
@ -1,278 +0,0 @@
|
||||||
from ctx_base import StandardBaseContext
|
|
||||||
|
|
||||||
import math
|
|
||||||
import cmath
|
|
||||||
import math2
|
|
||||||
|
|
||||||
import function_docs
|
|
||||||
|
|
||||||
from libmp import mpf_bernoulli, to_float, int_types
|
|
||||||
import libmp
|
|
||||||
|
|
||||||
class FPContext(StandardBaseContext):
|
|
||||||
"""
|
|
||||||
Context for fast low-precision arithmetic (53-bit precision, giving at most
|
|
||||||
about 15-digit accuracy), using Python's builtin float and complex.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(ctx):
|
|
||||||
StandardBaseContext.__init__(ctx)
|
|
||||||
|
|
||||||
# Override SpecialFunctions implementation
|
|
||||||
ctx.loggamma = math2.loggamma
|
|
||||||
ctx._bernoulli_cache = {}
|
|
||||||
ctx.pretty = False
|
|
||||||
|
|
||||||
ctx._init_aliases()
|
|
||||||
|
|
||||||
_mpq = lambda cls, x: float(x[0])/x[1]
|
|
||||||
|
|
||||||
NoConvergence = libmp.NoConvergence
|
|
||||||
|
|
||||||
def _get_prec(ctx): return 53
|
|
||||||
def _set_prec(ctx, p): return
|
|
||||||
def _get_dps(ctx): return 15
|
|
||||||
def _set_dps(ctx, p): return
|
|
||||||
|
|
||||||
_fixed_precision = True
|
|
||||||
|
|
||||||
prec = property(_get_prec, _set_prec)
|
|
||||||
dps = property(_get_dps, _set_dps)
|
|
||||||
|
|
||||||
zero = 0.0
|
|
||||||
one = 1.0
|
|
||||||
eps = math2.EPS
|
|
||||||
inf = math2.INF
|
|
||||||
ninf = math2.NINF
|
|
||||||
nan = math2.NAN
|
|
||||||
j = 1j
|
|
||||||
|
|
||||||
# Called by SpecialFunctions.__init__()
|
|
||||||
@classmethod
|
|
||||||
def _wrap_specfun(cls, name, f, wrap):
|
|
||||||
if wrap:
|
|
||||||
def f_wrapped(ctx, *args, **kwargs):
|
|
||||||
convert = ctx.convert
|
|
||||||
args = [convert(a) for a in args]
|
|
||||||
return f(ctx, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
f_wrapped = f
|
|
||||||
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
|
|
||||||
setattr(cls, name, f_wrapped)
|
|
||||||
|
|
||||||
def bernoulli(ctx, n):
|
|
||||||
cache = ctx._bernoulli_cache
|
|
||||||
if n in cache:
|
|
||||||
return cache[n]
|
|
||||||
cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
|
|
||||||
return cache[n]
|
|
||||||
|
|
||||||
pi = math2.pi
|
|
||||||
e = math2.e
|
|
||||||
euler = math2.euler
|
|
||||||
sqrt2 = 1.4142135623730950488
|
|
||||||
sqrt5 = 2.2360679774997896964
|
|
||||||
phi = 1.6180339887498948482
|
|
||||||
ln2 = 0.69314718055994530942
|
|
||||||
ln10 = 2.302585092994045684
|
|
||||||
euler = 0.57721566490153286061
|
|
||||||
catalan = 0.91596559417721901505
|
|
||||||
khinchin = 2.6854520010653064453
|
|
||||||
apery = 1.2020569031595942854
|
|
||||||
|
|
||||||
absmin = absmax = abs
|
|
||||||
|
|
||||||
def _as_points(ctx, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def fneg(ctx, x, **kwargs):
|
|
||||||
return -ctx.convert(x)
|
|
||||||
|
|
||||||
def fadd(ctx, x, y, **kwargs):
|
|
||||||
return ctx.convert(x)+ctx.convert(y)
|
|
||||||
|
|
||||||
def fsub(ctx, x, y, **kwargs):
|
|
||||||
return ctx.convert(x)-ctx.convert(y)
|
|
||||||
|
|
||||||
def fmul(ctx, x, y, **kwargs):
|
|
||||||
return ctx.convert(x)*ctx.convert(y)
|
|
||||||
|
|
||||||
def fdiv(ctx, x, y, **kwargs):
|
|
||||||
return ctx.convert(x)/ctx.convert(y)
|
|
||||||
|
|
||||||
def fsum(ctx, args, absolute=False, squared=False):
|
|
||||||
if absolute:
|
|
||||||
if squared:
|
|
||||||
return sum((abs(x)**2 for x in args), ctx.zero)
|
|
||||||
return sum((abs(x) for x in args), ctx.zero)
|
|
||||||
if squared:
|
|
||||||
return sum((x**2 for x in args), ctx.zero)
|
|
||||||
return sum(args, ctx.zero)
|
|
||||||
|
|
||||||
def fdot(ctx, xs, ys=None):
|
|
||||||
if ys is not None:
|
|
||||||
xs = zip(xs, ys)
|
|
||||||
return sum((x*y for (x,y) in xs), ctx.zero)
|
|
||||||
|
|
||||||
def is_special(ctx, x):
|
|
||||||
return x - x != 0.0
|
|
||||||
|
|
||||||
def isnan(ctx, x):
|
|
||||||
return x != x
|
|
||||||
|
|
||||||
def isinf(ctx, x):
|
|
||||||
return abs(x) == math2.INF
|
|
||||||
|
|
||||||
def isnpint(ctx, x):
|
|
||||||
if type(x) is complex:
|
|
||||||
if x.imag:
|
|
||||||
return False
|
|
||||||
x = x.real
|
|
||||||
return x <= 0.0 and round(x) == x
|
|
||||||
|
|
||||||
mpf = float
|
|
||||||
mpc = complex
|
|
||||||
|
|
||||||
def convert(ctx, x):
|
|
||||||
try:
|
|
||||||
return float(x)
|
|
||||||
except:
|
|
||||||
return complex(x)
|
|
||||||
|
|
||||||
power = staticmethod(math2.pow)
|
|
||||||
sqrt = staticmethod(math2.sqrt)
|
|
||||||
exp = staticmethod(math2.exp)
|
|
||||||
ln = log = staticmethod(math2.log)
|
|
||||||
cos = staticmethod(math2.cos)
|
|
||||||
sin = staticmethod(math2.sin)
|
|
||||||
tan = staticmethod(math2.tan)
|
|
||||||
cos_sin = staticmethod(math2.cos_sin)
|
|
||||||
acos = staticmethod(math2.acos)
|
|
||||||
asin = staticmethod(math2.asin)
|
|
||||||
atan = staticmethod(math2.atan)
|
|
||||||
cosh = staticmethod(math2.cosh)
|
|
||||||
sinh = staticmethod(math2.sinh)
|
|
||||||
tanh = staticmethod(math2.tanh)
|
|
||||||
gamma = staticmethod(math2.gamma)
|
|
||||||
fac = factorial = staticmethod(math2.factorial)
|
|
||||||
floor = staticmethod(math2.floor)
|
|
||||||
ceil = staticmethod(math2.ceil)
|
|
||||||
cospi = staticmethod(math2.cospi)
|
|
||||||
sinpi = staticmethod(math2.sinpi)
|
|
||||||
cbrt = staticmethod(math2.cbrt)
|
|
||||||
_nthroot = staticmethod(math2.nthroot)
|
|
||||||
_ei = staticmethod(math2.ei)
|
|
||||||
_e1 = staticmethod(math2.e1)
|
|
||||||
_zeta = _zeta_int = staticmethod(math2.zeta)
|
|
||||||
|
|
||||||
# XXX: math2
|
|
||||||
def arg(ctx, z):
|
|
||||||
z = complex(z)
|
|
||||||
return math.atan2(z.imag, z.real)
|
|
||||||
|
|
||||||
def expj(ctx, x):
|
|
||||||
return ctx.exp(ctx.j*x)
|
|
||||||
|
|
||||||
def expjpi(ctx, x):
|
|
||||||
return ctx.exp(ctx.j*ctx.pi*x)
|
|
||||||
|
|
||||||
ldexp = math.ldexp
|
|
||||||
frexp = math.frexp
|
|
||||||
|
|
||||||
def mag(ctx, z):
|
|
||||||
if z:
|
|
||||||
return ctx.frexp(abs(z))[1]
|
|
||||||
return ctx.ninf
|
|
||||||
|
|
||||||
def isint(ctx, z):
|
|
||||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
|
||||||
if z.imag:
|
|
||||||
return False
|
|
||||||
z = z.real
|
|
||||||
try:
|
|
||||||
return z == int(z)
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def nint_distance(ctx, z):
|
|
||||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
|
||||||
n = round(z.real)
|
|
||||||
else:
|
|
||||||
n = round(z)
|
|
||||||
if n == z:
|
|
||||||
return n, ctx.ninf
|
|
||||||
return n, ctx.mag(abs(z-n))
|
|
||||||
|
|
||||||
def _convert_param(ctx, z):
|
|
||||||
if type(z) is tuple:
|
|
||||||
p, q = z
|
|
||||||
return ctx.mpf(p) / q, 'R'
|
|
||||||
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
|
|
||||||
intz = int(z.real)
|
|
||||||
else:
|
|
||||||
intz = int(z)
|
|
||||||
if z == intz:
|
|
||||||
return intz, 'Z'
|
|
||||||
return z, 'R'
|
|
||||||
|
|
||||||
def _is_real_type(ctx, z):
|
|
||||||
return isinstance(z, float) or isinstance(z, int_types)
|
|
||||||
|
|
||||||
def _is_complex_type(ctx, z):
|
|
||||||
return isinstance(z, complex)
|
|
||||||
|
|
||||||
def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
|
|
||||||
coeffs = list(coeffs)
|
|
||||||
num = range(p)
|
|
||||||
den = range(p,p+q)
|
|
||||||
tol = ctx.eps
|
|
||||||
s = t = 1.0
|
|
||||||
k = 0
|
|
||||||
while 1:
|
|
||||||
for i in num: t *= (coeffs[i]+k)
|
|
||||||
for i in den: t /= (coeffs[i]+k)
|
|
||||||
k += 1; t /= k; t *= z; s += t
|
|
||||||
if abs(t) < tol:
|
|
||||||
return s
|
|
||||||
if k > maxterms:
|
|
||||||
raise ctx.NoConvergence
|
|
||||||
|
|
||||||
def atan2(ctx, x, y):
|
|
||||||
return math.atan2(x, y)
|
|
||||||
|
|
||||||
def psi(ctx, m, z):
|
|
||||||
m = int(m)
|
|
||||||
if m == 0:
|
|
||||||
return ctx.digamma(z)
|
|
||||||
return (-1)**(m+1) * ctx.fac(m) * ctx.zeta(m+1, z)
|
|
||||||
|
|
||||||
digamma = staticmethod(math2.digamma)
|
|
||||||
|
|
||||||
def harmonic(ctx, x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if x == 0 or x == 1:
|
|
||||||
return x
|
|
||||||
return ctx.digamma(x+1) + ctx.euler
|
|
||||||
|
|
||||||
nstr = str
|
|
||||||
|
|
||||||
def to_fixed(ctx, x, prec):
|
|
||||||
return int(math.ldexp(x, prec))
|
|
||||||
|
|
||||||
def rand(ctx):
|
|
||||||
import random
|
|
||||||
return random.random()
|
|
||||||
|
|
||||||
_erf = staticmethod(math2.erf)
|
|
||||||
_erfc = staticmethod(math2.erfc)
|
|
||||||
|
|
||||||
def sum_accurately(ctx, terms, check_step=1):
|
|
||||||
s = ctx.zero
|
|
||||||
k = 0
|
|
||||||
for term in terms():
|
|
||||||
s += term
|
|
||||||
if (not k % check_step) and term:
|
|
||||||
if abs(term) <= 1e-18*abs(s):
|
|
||||||
break
|
|
||||||
k += 1
|
|
||||||
return s
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,986 +0,0 @@
|
||||||
#from ctx_base import StandardBaseContext
|
|
||||||
|
|
||||||
from libmp import (MPZ, MPZ_ZERO, MPZ_ONE, int_types, repr_dps,
|
|
||||||
round_floor, round_ceiling, dps_to_prec, round_nearest, prec_to_dps,
|
|
||||||
ComplexResult, to_pickable, from_pickable, normalize,
|
|
||||||
from_int, from_float, from_str, to_int, to_float, to_str,
|
|
||||||
from_rational, from_man_exp,
|
|
||||||
fone, fzero, finf, fninf, fnan,
|
|
||||||
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul, mpf_mul_int,
|
|
||||||
mpf_div, mpf_rdiv_int, mpf_pow_int, mpf_mod,
|
|
||||||
mpf_eq, mpf_cmp, mpf_lt, mpf_gt, mpf_le, mpf_ge,
|
|
||||||
mpf_hash, mpf_rand,
|
|
||||||
mpf_sum,
|
|
||||||
bitcount, to_fixed,
|
|
||||||
mpc_to_str,
|
|
||||||
mpc_to_complex, mpc_hash, mpc_pos, mpc_is_nonzero, mpc_neg, mpc_conjugate,
|
|
||||||
mpc_abs, mpc_add, mpc_add_mpf, mpc_sub, mpc_sub_mpf, mpc_mul, mpc_mul_mpf,
|
|
||||||
mpc_mul_int, mpc_div, mpc_div_mpf, mpc_pow, mpc_pow_mpf, mpc_pow_int,
|
|
||||||
mpc_mpf_div,
|
|
||||||
mpf_pow,
|
|
||||||
mpi_mid, mpi_delta, mpi_str,
|
|
||||||
mpi_abs, mpi_pos, mpi_neg, mpi_add, mpi_sub,
|
|
||||||
mpi_mul, mpi_div, mpi_pow_int, mpi_pow,
|
|
||||||
mpf_pi, mpf_degree, mpf_e, mpf_phi, mpf_ln2, mpf_ln10,
|
|
||||||
mpf_euler, mpf_catalan, mpf_apery, mpf_khinchin,
|
|
||||||
mpf_glaisher, mpf_twinprime, mpf_mertens,
|
|
||||||
int_types)
|
|
||||||
|
|
||||||
import rational
|
|
||||||
import function_docs
|
|
||||||
|
|
||||||
new = object.__new__
|
|
||||||
|
|
||||||
class mpnumeric(object):
|
|
||||||
"""Base class for mpf and mpc."""
|
|
||||||
__slots__ = []
|
|
||||||
def __new__(cls, val):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
class _mpf(mpnumeric):
|
|
||||||
"""
|
|
||||||
An mpf instance holds a real-valued floating-point number. mpf:s
|
|
||||||
work analogously to Python floats, but support arbitrary-precision
|
|
||||||
arithmetic.
|
|
||||||
"""
|
|
||||||
__slots__ = ['_mpf_']
|
|
||||||
|
|
||||||
def __new__(cls, val=fzero, **kwargs):
|
|
||||||
"""A new mpf can be created from a Python float, an int, a
|
|
||||||
or a decimal string representing a number in floating-point
|
|
||||||
format."""
|
|
||||||
prec, rounding = cls.context._prec_rounding
|
|
||||||
if kwargs:
|
|
||||||
prec = kwargs.get('prec', prec)
|
|
||||||
if 'dps' in kwargs:
|
|
||||||
prec = dps_to_prec(kwargs['dps'])
|
|
||||||
rounding = kwargs.get('rounding', rounding)
|
|
||||||
if type(val) is cls:
|
|
||||||
sign, man, exp, bc = val._mpf_
|
|
||||||
if (not man) and exp:
|
|
||||||
return val
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = normalize(sign, man, exp, bc, prec, rounding)
|
|
||||||
return v
|
|
||||||
elif type(val) is tuple:
|
|
||||||
if len(val) == 2:
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = from_man_exp(val[0], val[1], prec, rounding)
|
|
||||||
return v
|
|
||||||
if len(val) == 4:
|
|
||||||
sign, man, exp, bc = val
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = normalize(sign, MPZ(man), exp, bc, prec, rounding)
|
|
||||||
return v
|
|
||||||
raise ValueError
|
|
||||||
else:
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_pos(cls.mpf_convert_arg(val, prec, rounding), prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def mpf_convert_arg(cls, x, prec, rounding):
|
|
||||||
if isinstance(x, int_types): return from_int(x)
|
|
||||||
if isinstance(x, float): return from_float(x)
|
|
||||||
if isinstance(x, basestring): return from_str(x, prec, rounding)
|
|
||||||
if isinstance(x, cls.context.constant): return x.func(prec, rounding)
|
|
||||||
if hasattr(x, '_mpf_'): return x._mpf_
|
|
||||||
if hasattr(x, '_mpmath_'):
|
|
||||||
t = cls.context.convert(x._mpmath_(prec, rounding))
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
return t._mpf_
|
|
||||||
raise TypeError("cannot create mpf from " + repr(x))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def mpf_convert_rhs(cls, x):
|
|
||||||
if isinstance(x, int_types): return from_int(x)
|
|
||||||
if isinstance(x, float): return from_float(x)
|
|
||||||
if isinstance(x, complex_types): return cls.context.mpc(x)
|
|
||||||
if isinstance(x, rational.mpq):
|
|
||||||
p, q = x
|
|
||||||
return from_rational(p, q, cls.context.prec)
|
|
||||||
if hasattr(x, '_mpf_'): return x._mpf_
|
|
||||||
if hasattr(x, '_mpmath_'):
|
|
||||||
t = cls.context.convert(x._mpmath_(*cls.context._prec_rounding))
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
return t._mpf_
|
|
||||||
return t
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def mpf_convert_lhs(cls, x):
|
|
||||||
x = cls.mpf_convert_rhs(x)
|
|
||||||
if type(x) is tuple:
|
|
||||||
return cls.context.make_mpf(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
man_exp = property(lambda self: self._mpf_[1:3])
|
|
||||||
man = property(lambda self: self._mpf_[1])
|
|
||||||
exp = property(lambda self: self._mpf_[2])
|
|
||||||
bc = property(lambda self: self._mpf_[3])
|
|
||||||
|
|
||||||
real = property(lambda self: self)
|
|
||||||
imag = property(lambda self: self.context.zero)
|
|
||||||
|
|
||||||
conjugate = lambda self: self
|
|
||||||
|
|
||||||
def __getstate__(self): return to_pickable(self._mpf_)
|
|
||||||
def __setstate__(self, val): self._mpf_ = from_pickable(val)
|
|
||||||
|
|
||||||
def __repr__(s):
|
|
||||||
if s.context.pretty:
|
|
||||||
return str(s)
|
|
||||||
return "mpf('%s')" % to_str(s._mpf_, s.context._repr_digits)
|
|
||||||
|
|
||||||
def __str__(s): return to_str(s._mpf_, s.context._str_digits)
|
|
||||||
def __hash__(s): return mpf_hash(s._mpf_)
|
|
||||||
def __int__(s): return int(to_int(s._mpf_))
|
|
||||||
def __long__(s): return long(to_int(s._mpf_))
|
|
||||||
def __float__(s): return to_float(s._mpf_)
|
|
||||||
def __complex__(s): return complex(float(s))
|
|
||||||
def __nonzero__(s): return s._mpf_ != fzero
|
|
||||||
|
|
||||||
def __abs__(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_abs(s._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __pos__(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_pos(s._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __neg__(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_neg(s._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def _cmp(s, t, func):
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
t = t._mpf_
|
|
||||||
else:
|
|
||||||
t = s.mpf_convert_rhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return func(s._mpf_, t)
|
|
||||||
|
|
||||||
def __cmp__(s, t): return s._cmp(t, mpf_cmp)
|
|
||||||
def __lt__(s, t): return s._cmp(t, mpf_lt)
|
|
||||||
def __gt__(s, t): return s._cmp(t, mpf_gt)
|
|
||||||
def __le__(s, t): return s._cmp(t, mpf_le)
|
|
||||||
def __ge__(s, t): return s._cmp(t, mpf_ge)
|
|
||||||
|
|
||||||
def __ne__(s, t):
|
|
||||||
v = s.__eq__(t)
|
|
||||||
if v is NotImplemented:
|
|
||||||
return v
|
|
||||||
return not v
|
|
||||||
|
|
||||||
def __rsub__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if type(t) in int_types:
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_sub(from_int(t), s._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpf_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t - s
|
|
||||||
|
|
||||||
def __rdiv__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if isinstance(t, int_types):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpf_ = mpf_rdiv_int(t, s._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpf_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t / s
|
|
||||||
|
|
||||||
def __rpow__(s, t):
|
|
||||||
t = s.mpf_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t ** s
|
|
||||||
|
|
||||||
def __rmod__(s, t):
|
|
||||||
t = s.mpf_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t % s
|
|
||||||
|
|
||||||
def sqrt(s):
|
|
||||||
return s.context.sqrt(s)
|
|
||||||
|
|
||||||
def ae(s, t, rel_eps=None, abs_eps=None):
|
|
||||||
return s.context.almosteq(s, t, rel_eps, abs_eps)
|
|
||||||
|
|
||||||
def to_fixed(self, prec):
|
|
||||||
return to_fixed(self._mpf_, prec)
|
|
||||||
|
|
||||||
|
|
||||||
mpf_binary_op = """
|
|
||||||
def %NAME%(self, other):
|
|
||||||
mpf, new, (prec, rounding) = self._ctxdata
|
|
||||||
sval = self._mpf_
|
|
||||||
if hasattr(other, '_mpf_'):
|
|
||||||
tval = other._mpf_
|
|
||||||
%WITH_MPF%
|
|
||||||
ttype = type(other)
|
|
||||||
if ttype in int_types:
|
|
||||||
%WITH_INT%
|
|
||||||
elif ttype is float:
|
|
||||||
tval = from_float(other)
|
|
||||||
%WITH_MPF%
|
|
||||||
elif hasattr(other, '_mpc_'):
|
|
||||||
tval = other._mpc_
|
|
||||||
mpc = type(other)
|
|
||||||
%WITH_MPC%
|
|
||||||
elif ttype is complex:
|
|
||||||
tval = from_float(other.real), from_float(other.imag)
|
|
||||||
mpc = self.context.mpc
|
|
||||||
%WITH_MPC%
|
|
||||||
if isinstance(other, mpnumeric):
|
|
||||||
return NotImplemented
|
|
||||||
try:
|
|
||||||
other = mpf.context.convert(other, strings=False)
|
|
||||||
except TypeError:
|
|
||||||
return NotImplemented
|
|
||||||
return self.%NAME%(other)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return_mpf = "; obj = new(mpf); obj._mpf_ = val; return obj"
|
|
||||||
return_mpc = "; obj = new(mpc); obj._mpc_ = val; return obj"
|
|
||||||
|
|
||||||
mpf_pow_same = """
|
|
||||||
try:
|
|
||||||
val = mpf_pow(sval, tval, prec, rounding) %s
|
|
||||||
except ComplexResult:
|
|
||||||
if mpf.context.trap_complex:
|
|
||||||
raise
|
|
||||||
mpc = mpf.context.mpc
|
|
||||||
val = mpc_pow((sval, fzero), (tval, fzero), prec, rounding) %s
|
|
||||||
""" % (return_mpf, return_mpc)
|
|
||||||
|
|
||||||
def binary_op(name, with_mpf='', with_int='', with_mpc=''):
|
|
||||||
code = mpf_binary_op
|
|
||||||
code = code.replace("%WITH_INT%", with_int)
|
|
||||||
code = code.replace("%WITH_MPC%", with_mpc)
|
|
||||||
code = code.replace("%WITH_MPF%", with_mpf)
|
|
||||||
code = code.replace("%NAME%", name)
|
|
||||||
np = {}
|
|
||||||
exec code in globals(), np
|
|
||||||
return np[name]
|
|
||||||
|
|
||||||
_mpf.__eq__ = binary_op('__eq__',
|
|
||||||
'return mpf_eq(sval, tval)',
|
|
||||||
'return mpf_eq(sval, from_int(other))',
|
|
||||||
'return (tval[1] == fzero) and mpf_eq(tval[0], sval)')
|
|
||||||
|
|
||||||
_mpf.__add__ = binary_op('__add__',
|
|
||||||
'val = mpf_add(sval, tval, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpf_add(sval, from_int(other), prec, rounding)' + return_mpf,
|
|
||||||
'val = mpc_add_mpf(tval, sval, prec, rounding)' + return_mpc)
|
|
||||||
|
|
||||||
_mpf.__sub__ = binary_op('__sub__',
|
|
||||||
'val = mpf_sub(sval, tval, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpf_sub(sval, from_int(other), prec, rounding)' + return_mpf,
|
|
||||||
'val = mpc_sub((sval, fzero), tval, prec, rounding)' + return_mpc)
|
|
||||||
|
|
||||||
_mpf.__mul__ = binary_op('__mul__',
|
|
||||||
'val = mpf_mul(sval, tval, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpf_mul_int(sval, other, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpc_mul_mpf(tval, sval, prec, rounding)' + return_mpc)
|
|
||||||
|
|
||||||
_mpf.__div__ = binary_op('__div__',
|
|
||||||
'val = mpf_div(sval, tval, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpf_div(sval, from_int(other), prec, rounding)' + return_mpf,
|
|
||||||
'val = mpc_mpf_div(sval, tval, prec, rounding)' + return_mpc)
|
|
||||||
|
|
||||||
_mpf.__mod__ = binary_op('__mod__',
|
|
||||||
'val = mpf_mod(sval, tval, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpf_mod(sval, from_int(other), prec, rounding)' + return_mpf,
|
|
||||||
'raise NotImplementedError("complex modulo")')
|
|
||||||
|
|
||||||
_mpf.__pow__ = binary_op('__pow__',
|
|
||||||
mpf_pow_same,
|
|
||||||
'val = mpf_pow_int(sval, other, prec, rounding)' + return_mpf,
|
|
||||||
'val = mpc_pow((sval, fzero), tval, prec, rounding)' + return_mpc)
|
|
||||||
|
|
||||||
_mpf.__radd__ = _mpf.__add__
|
|
||||||
_mpf.__rmul__ = _mpf.__mul__
|
|
||||||
_mpf.__truediv__ = _mpf.__div__
|
|
||||||
_mpf.__rtruediv__ = _mpf.__rdiv__
|
|
||||||
|
|
||||||
|
|
||||||
class _constant(_mpf):
|
|
||||||
"""Represents a mathematical constant with dynamic precision.
|
|
||||||
When printed or used in an arithmetic operation, a constant
|
|
||||||
is converted to a regular mpf at the working precision. A
|
|
||||||
regular mpf can also be obtained using the operation +x."""
|
|
||||||
|
|
||||||
def __new__(cls, func, name, docname=''):
|
|
||||||
a = object.__new__(cls)
|
|
||||||
a.name = name
|
|
||||||
a.func = func
|
|
||||||
a.__doc__ = getattr(function_docs, docname, '')
|
|
||||||
return a
|
|
||||||
|
|
||||||
def __call__(self, prec=None, dps=None, rounding=None):
|
|
||||||
prec2, rounding2 = self.context._prec_rounding
|
|
||||||
if not prec: prec = prec2
|
|
||||||
if not rounding: rounding = rounding2
|
|
||||||
if dps: prec = dps_to_prec(dps)
|
|
||||||
return self.context.make_mpf(self.func(prec, rounding))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _mpf_(self):
|
|
||||||
prec, rounding = self.context._prec_rounding
|
|
||||||
return self.func(prec, rounding)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<%s: %s~>" % (self.name, self.context.nstr(self))
|
|
||||||
|
|
||||||
|
|
||||||
class _mpc(mpnumeric):
|
|
||||||
"""
|
|
||||||
An mpc represents a complex number using a pair of mpf:s (one
|
|
||||||
for the real part and another for the imaginary part.) The mpc
|
|
||||||
class behaves fairly similarly to Python's complex type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ['_mpc_']
|
|
||||||
|
|
||||||
def __new__(cls, real=0, imag=0):
|
|
||||||
s = object.__new__(cls)
|
|
||||||
if isinstance(real, complex_types):
|
|
||||||
real, imag = real.real, real.imag
|
|
||||||
elif hasattr(real, '_mpc_'):
|
|
||||||
s._mpc_ = real._mpc_
|
|
||||||
return s
|
|
||||||
real = cls.context.mpf(real)
|
|
||||||
imag = cls.context.mpf(imag)
|
|
||||||
s._mpc_ = (real._mpf_, imag._mpf_)
|
|
||||||
return s
|
|
||||||
|
|
||||||
real = property(lambda self: self.context.make_mpf(self._mpc_[0]))
|
|
||||||
imag = property(lambda self: self.context.make_mpf(self._mpc_[1]))
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
return to_pickable(self._mpc_[0]), to_pickable(self._mpc_[1])
|
|
||||||
|
|
||||||
def __setstate__(self, val):
|
|
||||||
self._mpc_ = from_pickable(val[0]), from_pickable(val[1])
|
|
||||||
|
|
||||||
def __repr__(s):
|
|
||||||
if s.context.pretty:
|
|
||||||
return str(s)
|
|
||||||
r = repr(s.real)[4:-1]
|
|
||||||
i = repr(s.imag)[4:-1]
|
|
||||||
return "%s(real=%s, imag=%s)" % (type(s).__name__, r, i)
|
|
||||||
|
|
||||||
def __str__(s):
|
|
||||||
return "(%s)" % mpc_to_str(s._mpc_, s.context._str_digits)
|
|
||||||
|
|
||||||
def __complex__(s):
|
|
||||||
return mpc_to_complex(s._mpc_)
|
|
||||||
|
|
||||||
def __pos__(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_pos(s._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __abs__(s):
|
|
||||||
prec, rounding = s.context._prec_rounding
|
|
||||||
v = new(s.context.mpf)
|
|
||||||
v._mpf_ = mpc_abs(s._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __neg__(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_neg(s._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def conjugate(s):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_conjugate(s._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __nonzero__(s):
|
|
||||||
return mpc_is_nonzero(s._mpc_)
|
|
||||||
|
|
||||||
def __hash__(s):
|
|
||||||
return mpc_hash(s._mpc_)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def mpc_convert_lhs(cls, x):
|
|
||||||
try:
|
|
||||||
y = cls.context.convert(x)
|
|
||||||
return y
|
|
||||||
except TypeError:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __eq__(s, t):
|
|
||||||
if not hasattr(t, '_mpc_'):
|
|
||||||
if isinstance(t, str):
|
|
||||||
return False
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return s.real == t.real and s.imag == t.imag
|
|
||||||
|
|
||||||
def __ne__(s, t):
|
|
||||||
b = s.__eq__(t)
|
|
||||||
if b is NotImplemented:
|
|
||||||
return b
|
|
||||||
return not b
|
|
||||||
|
|
||||||
def _compare(*args):
|
|
||||||
raise TypeError("no ordering relation is defined for complex numbers")
|
|
||||||
|
|
||||||
__gt__ = _compare
|
|
||||||
__le__ = _compare
|
|
||||||
__gt__ = _compare
|
|
||||||
__ge__ = _compare
|
|
||||||
|
|
||||||
def __add__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if not hasattr(t, '_mpc_'):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_add_mpf(s._mpc_, t._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_add(s._mpc_, t._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __sub__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if not hasattr(t, '_mpc_'):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_sub_mpf(s._mpc_, t._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_sub(s._mpc_, t._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __mul__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if not hasattr(t, '_mpc_'):
|
|
||||||
if isinstance(t, int_types):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_mul_mpf(s._mpc_, t._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_mul(s._mpc_, t._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __div__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if not hasattr(t, '_mpc_'):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_div_mpf(s._mpc_, t._mpf_, prec, rounding)
|
|
||||||
return v
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_div(s._mpc_, t._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def __pow__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if isinstance(t, int_types):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_pow_int(s._mpc_, t, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
v = new(cls)
|
|
||||||
if hasattr(t, '_mpf_'):
|
|
||||||
v._mpc_ = mpc_pow_mpf(s._mpc_, t._mpf_, prec, rounding)
|
|
||||||
else:
|
|
||||||
v._mpc_ = mpc_pow(s._mpc_, t._mpc_, prec, rounding)
|
|
||||||
return v
|
|
||||||
|
|
||||||
__radd__ = __add__
|
|
||||||
|
|
||||||
def __rsub__(s, t):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t - s
|
|
||||||
|
|
||||||
def __rmul__(s, t):
|
|
||||||
cls, new, (prec, rounding) = s._ctxdata
|
|
||||||
if isinstance(t, int_types):
|
|
||||||
v = new(cls)
|
|
||||||
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
|
|
||||||
return v
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t * s
|
|
||||||
|
|
||||||
def __rdiv__(s, t):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t / s
|
|
||||||
|
|
||||||
def __rpow__(s, t):
|
|
||||||
t = s.mpc_convert_lhs(t)
|
|
||||||
if t is NotImplemented:
|
|
||||||
return t
|
|
||||||
return t ** s
|
|
||||||
|
|
||||||
__truediv__ = __div__
|
|
||||||
__rtruediv__ = __rdiv__
|
|
||||||
|
|
||||||
def ae(s, t, rel_eps=None, abs_eps=None):
|
|
||||||
return s.context.almosteq(s, t, rel_eps, abs_eps)
|
|
||||||
|
|
||||||
|
|
||||||
complex_types = (complex, _mpc)
|
|
||||||
|
|
||||||
|
|
||||||
class PythonMPContext:
|
|
||||||
|
|
||||||
def __init__(ctx):
|
|
||||||
ctx._prec_rounding = [53, round_nearest]
|
|
||||||
ctx.mpf = type('mpf', (_mpf,), {})
|
|
||||||
ctx.mpc = type('mpc', (_mpc,), {})
|
|
||||||
ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
|
|
||||||
ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding]
|
|
||||||
ctx.mpf.context = ctx
|
|
||||||
ctx.mpc.context = ctx
|
|
||||||
ctx.constant = type('constant', (_constant,), {})
|
|
||||||
ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
|
|
||||||
ctx.constant.context = ctx
|
|
||||||
|
|
||||||
def make_mpf(ctx, v):
|
|
||||||
a = new(ctx.mpf)
|
|
||||||
a._mpf_ = v
|
|
||||||
return a
|
|
||||||
|
|
||||||
def make_mpc(ctx, v):
|
|
||||||
a = new(ctx.mpc)
|
|
||||||
a._mpc_ = v
|
|
||||||
return a
|
|
||||||
|
|
||||||
def default(ctx):
|
|
||||||
ctx._prec = ctx._prec_rounding[0] = 53
|
|
||||||
ctx._dps = 15
|
|
||||||
ctx.trap_complex = False
|
|
||||||
|
|
||||||
def _set_prec(ctx, n):
|
|
||||||
ctx._prec = ctx._prec_rounding[0] = max(1, int(n))
|
|
||||||
ctx._dps = prec_to_dps(n)
|
|
||||||
|
|
||||||
def _set_dps(ctx, n):
|
|
||||||
ctx._prec = ctx._prec_rounding[0] = dps_to_prec(n)
|
|
||||||
ctx._dps = max(1, int(n))
|
|
||||||
|
|
||||||
prec = property(lambda ctx: ctx._prec, _set_prec)
|
|
||||||
dps = property(lambda ctx: ctx._dps, _set_dps)
|
|
||||||
|
|
||||||
def convert(ctx, x, strings=True):
|
|
||||||
"""
|
|
||||||
Converts *x* to an ``mpf``, ``mpc`` or ``mpi``. If *x* is of type ``mpf``,
|
|
||||||
``mpc``, ``int``, ``float``, ``complex``, the conversion
|
|
||||||
will be performed losslessly.
|
|
||||||
|
|
||||||
If *x* is a string, the result will be rounded to the present
|
|
||||||
working precision. Strings representing fractions or complex
|
|
||||||
numbers are permitted.
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> mpmathify(3.5)
|
|
||||||
mpf('3.5')
|
|
||||||
>>> mpmathify('2.1')
|
|
||||||
mpf('2.1000000000000001')
|
|
||||||
>>> mpmathify('3/4')
|
|
||||||
mpf('0.75')
|
|
||||||
>>> mpmathify('2+3j')
|
|
||||||
mpc(real='2.0', imag='3.0')
|
|
||||||
|
|
||||||
"""
|
|
||||||
if type(x) in ctx.types: return x
|
|
||||||
if isinstance(x, int_types): return ctx.make_mpf(from_int(x))
|
|
||||||
if isinstance(x, float): return ctx.make_mpf(from_float(x))
|
|
||||||
if isinstance(x, complex):
|
|
||||||
return ctx.make_mpc((from_float(x.real), from_float(x.imag)))
|
|
||||||
prec, rounding = ctx._prec_rounding
|
|
||||||
if isinstance(x, rational.mpq):
|
|
||||||
p, q = x
|
|
||||||
return ctx.make_mpf(from_rational(p, q, prec))
|
|
||||||
if strings and isinstance(x, basestring):
|
|
||||||
try:
|
|
||||||
_mpf_ = from_str(x, prec, rounding)
|
|
||||||
return ctx.make_mpf(_mpf_)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
if hasattr(x, '_mpf_'): return ctx.make_mpf(x._mpf_)
|
|
||||||
if hasattr(x, '_mpc_'): return ctx.make_mpc(x._mpc_)
|
|
||||||
if hasattr(x, '_mpmath_'):
|
|
||||||
return ctx.convert(x._mpmath_(prec, rounding))
|
|
||||||
return ctx._convert_fallback(x, strings)
|
|
||||||
|
|
||||||
def isnan(ctx, x):
|
|
||||||
"""
|
|
||||||
For an ``mpf`` *x*, determines whether *x* is not-a-number (nan)::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> isnan(nan), isnan(3)
|
|
||||||
(True, False)
|
|
||||||
"""
|
|
||||||
if not hasattr(x, '_mpf_'):
|
|
||||||
return False
|
|
||||||
return x._mpf_ == fnan
|
|
||||||
|
|
||||||
def isinf(ctx, x):
|
|
||||||
"""
|
|
||||||
For an ``mpf`` *x*, determines whether *x* is infinite::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> isinf(inf), isinf(-inf), isinf(3)
|
|
||||||
(True, True, False)
|
|
||||||
"""
|
|
||||||
if not hasattr(x, '_mpf_'):
|
|
||||||
return False
|
|
||||||
return x._mpf_ in (finf, fninf)
|
|
||||||
|
|
||||||
def isint(ctx, x):
|
|
||||||
"""
|
|
||||||
For an ``mpf`` *x*, or any type that can be converted
|
|
||||||
to ``mpf``, determines whether *x* is exactly
|
|
||||||
integer-valued::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> isint(3), isint(mpf(3)), isint(3.2)
|
|
||||||
(True, True, False)
|
|
||||||
"""
|
|
||||||
if isinstance(x, int_types):
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
x = ctx.convert(x)
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
if hasattr(x, '_mpf_'):
|
|
||||||
if ctx.isnan(x) or ctx.isinf(x):
|
|
||||||
return False
|
|
||||||
return x == int(x)
|
|
||||||
if isinstance(x, ctx.mpq):
|
|
||||||
p, q = x
|
|
||||||
return not (p % q)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def fsum(ctx, terms, absolute=False, squared=False):
|
|
||||||
"""
|
|
||||||
Calculates a sum containing a finite number of terms (for infinite
|
|
||||||
series, see :func:`nsum`). The terms will be converted to
|
|
||||||
mpmath numbers. For len(terms) > 2, this function is generally
|
|
||||||
faster and produces more accurate results than the builtin
|
|
||||||
Python function :func:`sum`.
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> fsum([1, 2, 0.5, 7])
|
|
||||||
mpf('10.5')
|
|
||||||
|
|
||||||
With squared=True each term is squared, and with absolute=True
|
|
||||||
the absolute value of each term is used.
|
|
||||||
"""
|
|
||||||
prec, rnd = ctx._prec_rounding
|
|
||||||
real = []
|
|
||||||
imag = []
|
|
||||||
other = 0
|
|
||||||
for term in terms:
|
|
||||||
reval = imval = 0
|
|
||||||
if hasattr(term, "_mpf_"):
|
|
||||||
reval = term._mpf_
|
|
||||||
elif hasattr(term, "_mpc_"):
|
|
||||||
reval, imval = term._mpc_
|
|
||||||
else:
|
|
||||||
term = ctx.convert(term)
|
|
||||||
if hasattr(term, "_mpf_"):
|
|
||||||
reval = term._mpf_
|
|
||||||
elif hasattr(term, "_mpc_"):
|
|
||||||
reval, imval = term._mpc_
|
|
||||||
else:
|
|
||||||
if absolute: term = ctx.absmax(term)
|
|
||||||
if squared: term = term**2
|
|
||||||
other += term
|
|
||||||
continue
|
|
||||||
if imval:
|
|
||||||
if squared:
|
|
||||||
if absolute:
|
|
||||||
real.append(mpf_mul(reval,reval))
|
|
||||||
real.append(mpf_mul(imval,imval))
|
|
||||||
else:
|
|
||||||
reval, imval = mpc_pow_int((reval,imval),2,prec+10)
|
|
||||||
real.append(reval)
|
|
||||||
imag.append(imval)
|
|
||||||
elif absolute:
|
|
||||||
real.append(mpc_abs((reval,imval), prec))
|
|
||||||
else:
|
|
||||||
real.append(reval)
|
|
||||||
imag.append(imval)
|
|
||||||
else:
|
|
||||||
if squared:
|
|
||||||
reval = mpf_mul(reval, reval)
|
|
||||||
elif absolute:
|
|
||||||
reval = mpf_abs(reval)
|
|
||||||
real.append(reval)
|
|
||||||
s = mpf_sum(real, prec, rnd, absolute)
|
|
||||||
if imag:
|
|
||||||
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
|
|
||||||
else:
|
|
||||||
s = ctx.make_mpf(s)
|
|
||||||
if other is 0:
|
|
||||||
return s
|
|
||||||
else:
|
|
||||||
return s + other
|
|
||||||
|
|
||||||
def fdot(ctx, A, B=None):
|
|
||||||
r"""
|
|
||||||
Computes the dot product of the iterables `A` and `B`,
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\sum_{k=0} A_k B_k.
|
|
||||||
|
|
||||||
Alternatively, :func:`fdot` accepts a single iterable of pairs.
|
|
||||||
In other words, ``fdot(A,B)`` and ``fdot(zip(A,B))`` are equivalent.
|
|
||||||
|
|
||||||
The elements are automatically converted to mpmath numbers.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> A = [2, 1.5, 3]
|
|
||||||
>>> B = [1, -1, 2]
|
|
||||||
>>> fdot(A, B)
|
|
||||||
mpf('6.5')
|
|
||||||
>>> zip(A, B)
|
|
||||||
[(2, 1), (1.5, -1), (3, 2)]
|
|
||||||
>>> fdot(_)
|
|
||||||
mpf('6.5')
|
|
||||||
|
|
||||||
"""
|
|
||||||
if B:
|
|
||||||
A = zip(A, B)
|
|
||||||
prec, rnd = ctx._prec_rounding
|
|
||||||
real = []
|
|
||||||
imag = []
|
|
||||||
other = 0
|
|
||||||
hasattr_ = hasattr
|
|
||||||
types = (ctx.mpf, ctx.mpc)
|
|
||||||
for a, b in A:
|
|
||||||
if type(a) not in types: a = ctx.convert(a)
|
|
||||||
if type(b) not in types: b = ctx.convert(b)
|
|
||||||
a_real = hasattr_(a, "_mpf_")
|
|
||||||
b_real = hasattr_(b, "_mpf_")
|
|
||||||
if a_real and b_real:
|
|
||||||
real.append(mpf_mul(a._mpf_, b._mpf_))
|
|
||||||
continue
|
|
||||||
a_complex = hasattr_(a, "_mpc_")
|
|
||||||
b_complex = hasattr_(b, "_mpc_")
|
|
||||||
if a_real and b_complex:
|
|
||||||
aval = a._mpf_
|
|
||||||
bre, bim = b._mpc_
|
|
||||||
real.append(mpf_mul(aval, bre))
|
|
||||||
imag.append(mpf_mul(aval, bim))
|
|
||||||
elif b_real and a_complex:
|
|
||||||
are, aim = a._mpc_
|
|
||||||
bval = b._mpf_
|
|
||||||
real.append(mpf_mul(are, bval))
|
|
||||||
imag.append(mpf_mul(aim, bval))
|
|
||||||
elif a_complex and b_complex:
|
|
||||||
re, im = mpc_mul(a._mpc_, b._mpc_, prec+20)
|
|
||||||
real.append(re)
|
|
||||||
imag.append(im)
|
|
||||||
else:
|
|
||||||
other += a*b
|
|
||||||
s = mpf_sum(real, prec, rnd)
|
|
||||||
if imag:
|
|
||||||
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
|
|
||||||
else:
|
|
||||||
s = ctx.make_mpf(s)
|
|
||||||
if other is 0:
|
|
||||||
return s
|
|
||||||
else:
|
|
||||||
return s + other
|
|
||||||
|
|
||||||
def _wrap_libmp_function(ctx, mpf_f, mpc_f=None, mpi_f=None, doc="<no doc>"):
|
|
||||||
"""
|
|
||||||
Given a low-level mpf_ function, and optionally similar functions
|
|
||||||
for mpc_ and mpi_, defines the function as a context method.
|
|
||||||
|
|
||||||
It is assumed that the return type is the same as that of
|
|
||||||
the input; the exception is that propagation from mpf to mpc is possible
|
|
||||||
by raising ComplexResult.
|
|
||||||
|
|
||||||
"""
|
|
||||||
def f(x, **kwargs):
|
|
||||||
if type(x) not in ctx.types:
|
|
||||||
x = ctx.convert(x)
|
|
||||||
prec, rounding = ctx._prec_rounding
|
|
||||||
if kwargs:
|
|
||||||
prec = kwargs.get('prec', prec)
|
|
||||||
if 'dps' in kwargs:
|
|
||||||
prec = dps_to_prec(kwargs['dps'])
|
|
||||||
rounding = kwargs.get('rounding', rounding)
|
|
||||||
if hasattr(x, '_mpf_'):
|
|
||||||
try:
|
|
||||||
return ctx.make_mpf(mpf_f(x._mpf_, prec, rounding))
|
|
||||||
except ComplexResult:
|
|
||||||
# Handle propagation to complex
|
|
||||||
if ctx.trap_complex:
|
|
||||||
raise
|
|
||||||
return ctx.make_mpc(mpc_f((x._mpf_, fzero), prec, rounding))
|
|
||||||
elif hasattr(x, '_mpc_'):
|
|
||||||
return ctx.make_mpc(mpc_f(x._mpc_, prec, rounding))
|
|
||||||
elif hasattr(x, '_mpi_'):
|
|
||||||
if mpi_f:
|
|
||||||
return ctx.make_mpi(mpi_f(x._mpi_, prec))
|
|
||||||
raise NotImplementedError("%s of a %s" % (name, type(x)))
|
|
||||||
name = mpf_f.__name__[4:]
|
|
||||||
f.__doc__ = function_docs.__dict__.get(name, "Computes the %s of x" % doc)
|
|
||||||
return f
|
|
||||||
|
|
||||||
# Called by SpecialFunctions.__init__()
|
|
||||||
@classmethod
|
|
||||||
def _wrap_specfun(cls, name, f, wrap):
|
|
||||||
if wrap:
|
|
||||||
def f_wrapped(ctx, *args, **kwargs):
|
|
||||||
convert = ctx.convert
|
|
||||||
args = [convert(a) for a in args]
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
retval = f(ctx, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return +retval
|
|
||||||
else:
|
|
||||||
f_wrapped = f
|
|
||||||
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
|
|
||||||
setattr(cls, name, f_wrapped)
|
|
||||||
|
|
||||||
def _convert_param(ctx, x):
|
|
||||||
if hasattr(x, "_mpc_"):
|
|
||||||
v, im = x._mpc_
|
|
||||||
if im != fzero:
|
|
||||||
return x, 'C'
|
|
||||||
elif hasattr(x, "_mpf_"):
|
|
||||||
v = x._mpf_
|
|
||||||
else:
|
|
||||||
if type(x) in int_types:
|
|
||||||
return int(x), 'Z'
|
|
||||||
p = None
|
|
||||||
if isinstance(x, tuple):
|
|
||||||
p, q = x
|
|
||||||
elif isinstance(x, basestring) and '/' in x:
|
|
||||||
p, q = x.split('/')
|
|
||||||
p = int(p)
|
|
||||||
q = int(q)
|
|
||||||
if p is not None:
|
|
||||||
if not p % q:
|
|
||||||
return p // q, 'Z'
|
|
||||||
return ctx.mpq((p,q)), 'Q'
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if hasattr(x, "_mpc_"):
|
|
||||||
v, im = x._mpc_
|
|
||||||
if im != fzero:
|
|
||||||
return x, 'C'
|
|
||||||
elif hasattr(x, "_mpf_"):
|
|
||||||
v = x._mpf_
|
|
||||||
else:
|
|
||||||
return x, 'U'
|
|
||||||
sign, man, exp, bc = v
|
|
||||||
if man:
|
|
||||||
if exp >= -4:
|
|
||||||
if sign:
|
|
||||||
man = -man
|
|
||||||
if exp >= 0:
|
|
||||||
return int(man) << exp, 'Z'
|
|
||||||
if exp >= -4:
|
|
||||||
p, q = int(man), (1<<(-exp))
|
|
||||||
return ctx.mpq((p,q)), 'Q'
|
|
||||||
x = ctx.make_mpf(v)
|
|
||||||
return x, 'R'
|
|
||||||
elif not exp:
|
|
||||||
return 0, 'Z'
|
|
||||||
else:
|
|
||||||
return x, 'U'
|
|
||||||
|
|
||||||
def _mpf_mag(ctx, x):
|
|
||||||
sign, man, exp, bc = x
|
|
||||||
if man:
|
|
||||||
return exp+bc
|
|
||||||
if x == fzero:
|
|
||||||
return ctx.ninf
|
|
||||||
if x == finf or x == fninf:
|
|
||||||
return ctx.inf
|
|
||||||
return ctx.nan
|
|
||||||
|
|
||||||
def mag(ctx, x):
|
|
||||||
"""
|
|
||||||
Quick logarithmic magnitude estimate of a number.
|
|
||||||
Returns an integer or infinity `m` such that `|x| <= 2^m`.
|
|
||||||
It is not guaranteed that `m` is an optimal bound,
|
|
||||||
but it will never be off by more than 2 (and probably not
|
|
||||||
more than 1).
|
|
||||||
"""
|
|
||||||
if hasattr(x, "_mpf_"):
|
|
||||||
return ctx._mpf_mag(x._mpf_)
|
|
||||||
elif hasattr(x, "_mpc_"):
|
|
||||||
r, i = x._mpc_
|
|
||||||
if r == fzero:
|
|
||||||
return ctx._mpf_mag(i)
|
|
||||||
if i == fzero:
|
|
||||||
return ctx._mpf_mag(r)
|
|
||||||
return 1+max(ctx._mpf_mag(r), ctx._mpf_mag(i))
|
|
||||||
elif isinstance(x, int_types):
|
|
||||||
if x:
|
|
||||||
return bitcount(abs(x))
|
|
||||||
return ctx.ninf
|
|
||||||
elif isinstance(x, rational.mpq):
|
|
||||||
p, q = x
|
|
||||||
if p:
|
|
||||||
return 1 + bitcount(abs(p)) - bitcount(abs(q))
|
|
||||||
return ctx.ninf
|
|
||||||
else:
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if hasattr(x, "_mpf_") or hasattr(x, "_mpc_"):
|
|
||||||
return ctx.mag(x)
|
|
||||||
else:
|
|
||||||
raise TypeError("requires an mpf/mpc")
|
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,7 +0,0 @@
|
||||||
import functions
|
|
||||||
# Hack to update methods
|
|
||||||
import factorials
|
|
||||||
import hypergeometric
|
|
||||||
import elliptic
|
|
||||||
import zeta
|
|
||||||
import rszeta
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,196 +0,0 @@
|
||||||
from functions import defun, defun_wrapped
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def gammaprod(ctx, a, b, _infsign=False):
|
|
||||||
a = [ctx.convert(x) for x in a]
|
|
||||||
b = [ctx.convert(x) for x in b]
|
|
||||||
poles_num = []
|
|
||||||
poles_den = []
|
|
||||||
regular_num = []
|
|
||||||
regular_den = []
|
|
||||||
for x in a: [regular_num, poles_num][ctx.isnpint(x)].append(x)
|
|
||||||
for x in b: [regular_den, poles_den][ctx.isnpint(x)].append(x)
|
|
||||||
# One more pole in numerator or denominator gives 0 or inf
|
|
||||||
if len(poles_num) < len(poles_den): return ctx.zero
|
|
||||||
if len(poles_num) > len(poles_den):
|
|
||||||
# Get correct sign of infinity for x+h, h -> 0 from above
|
|
||||||
# XXX: hack, this should be done properly
|
|
||||||
if _infsign:
|
|
||||||
a = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_num]
|
|
||||||
b = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_den]
|
|
||||||
return ctx.sign(ctx.gammaprod(a+regular_num,b+regular_den)) * ctx.inf
|
|
||||||
else:
|
|
||||||
return ctx.inf
|
|
||||||
# All poles cancel
|
|
||||||
# lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
|
|
||||||
p = ctx.one
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec = orig + 15
|
|
||||||
while poles_num:
|
|
||||||
i = poles_num.pop()
|
|
||||||
j = poles_den.pop()
|
|
||||||
p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
|
|
||||||
for x in regular_num: p *= ctx.gamma(x)
|
|
||||||
for x in regular_den: p /= ctx.gamma(x)
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
return +p
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def beta(ctx, x, y):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
y = ctx.convert(y)
|
|
||||||
if ctx.isinf(y):
|
|
||||||
x, y = y, x
|
|
||||||
if ctx.isinf(x):
|
|
||||||
if x == ctx.inf and not ctx._im(y):
|
|
||||||
if y == ctx.ninf:
|
|
||||||
return ctx.nan
|
|
||||||
if y > 0:
|
|
||||||
return ctx.zero
|
|
||||||
if ctx.isint(y):
|
|
||||||
return ctx.nan
|
|
||||||
if y < 0:
|
|
||||||
return ctx.sign(ctx.gamma(y)) * ctx.inf
|
|
||||||
return ctx.nan
|
|
||||||
return ctx.gammaprod([x, y], [x+y])
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def binomial(ctx, n, k):
|
|
||||||
return ctx.gammaprod([n+1], [k+1, n-k+1])
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def rf(ctx, x, n):
|
|
||||||
return ctx.gammaprod([x+n], [x])
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def ff(ctx, x, n):
|
|
||||||
return ctx.gammaprod([x+1], [x-n+1])
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def fac2(ctx, x):
|
|
||||||
if ctx.isinf(x):
|
|
||||||
if x == ctx.inf:
|
|
||||||
return x
|
|
||||||
return ctx.nan
|
|
||||||
return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def barnesg(ctx, z):
|
|
||||||
if ctx.isinf(z):
|
|
||||||
if z == ctx.inf:
|
|
||||||
return z
|
|
||||||
return ctx.nan
|
|
||||||
if ctx.isnan(z):
|
|
||||||
return z
|
|
||||||
if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
|
|
||||||
return z*0
|
|
||||||
# Account for size (would not be needed if computing log(G))
|
|
||||||
if abs(z) > 5:
|
|
||||||
ctx.dps += 2*ctx.log(abs(z),2)
|
|
||||||
# Reflection formula
|
|
||||||
if ctx.re(z) < -ctx.dps:
|
|
||||||
w = 1-z
|
|
||||||
pi2 = 2*ctx.pi
|
|
||||||
u = ctx.expjpi(2*w)
|
|
||||||
v = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) - \
|
|
||||||
ctx.j*ctx.polylog(2, u)/pi2
|
|
||||||
v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
|
|
||||||
if ctx._is_real_type(z):
|
|
||||||
v = ctx._re(v)
|
|
||||||
return v
|
|
||||||
# Estimate terms for asymptotic expansion
|
|
||||||
# TODO: fixme, obviously
|
|
||||||
N = ctx.dps // 2 + 5
|
|
||||||
G = 1
|
|
||||||
while abs(z) < N or ctx.re(z) < 1:
|
|
||||||
G /= ctx.gamma(z)
|
|
||||||
z += 1
|
|
||||||
z -= 1
|
|
||||||
s = ctx.mpf(1)/12
|
|
||||||
s -= ctx.log(ctx.glaisher)
|
|
||||||
s += z*ctx.log(2*ctx.pi)/2
|
|
||||||
s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
|
|
||||||
s -= 3*z**2/4
|
|
||||||
z2k = z2 = z**2
|
|
||||||
for k in xrange(1, N+1):
|
|
||||||
t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
|
|
||||||
if abs(t) < ctx.eps:
|
|
||||||
#print k, N # check how many terms were needed
|
|
||||||
break
|
|
||||||
z2k *= z2
|
|
||||||
s += t
|
|
||||||
#if k == N:
|
|
||||||
# print "warning: series for barnesg failed to converge", ctx.dps
|
|
||||||
return G*ctx.exp(s)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def superfac(ctx, z):
|
|
||||||
return ctx.barnesg(z+2)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def hyperfac(ctx, z):
|
|
||||||
# XXX: estimate needed extra bits accurately
|
|
||||||
if z == ctx.inf:
|
|
||||||
return z
|
|
||||||
if abs(z) > 5:
|
|
||||||
extra = 4*int(ctx.log(abs(z),2))
|
|
||||||
else:
|
|
||||||
extra = 0
|
|
||||||
ctx.prec += extra
|
|
||||||
if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
|
|
||||||
n = int(ctx.re(z))
|
|
||||||
h = ctx.hyperfac(-n-1)
|
|
||||||
if ((n+1)//2) & 1:
|
|
||||||
h = -h
|
|
||||||
if ctx._is_complex_type(z):
|
|
||||||
return h + 0j
|
|
||||||
return h
|
|
||||||
zp1 = z+1
|
|
||||||
# Wrong branch cut
|
|
||||||
#v = ctx.gamma(zp1)**z
|
|
||||||
#ctx.prec -= extra
|
|
||||||
#return v / ctx.barnesg(zp1)
|
|
||||||
v = ctx.exp(z*ctx.loggamma(zp1))
|
|
||||||
ctx.prec -= extra
|
|
||||||
return v / ctx.barnesg(zp1)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def loggamma(ctx, z):
|
|
||||||
a = ctx._re(z)
|
|
||||||
b = ctx._im(z)
|
|
||||||
if not b and a > 0:
|
|
||||||
return ctx.ln(ctx.gamma(z))
|
|
||||||
u = ctx.arg(z)
|
|
||||||
w = ctx.ln(ctx.gamma(z))
|
|
||||||
if b:
|
|
||||||
gi = -b - u/2 + a*u + b*ctx.ln(abs(z))
|
|
||||||
n = ctx.floor((gi-ctx._im(w))/(2*ctx.pi)+0.5) * (2*ctx.pi)
|
|
||||||
return w + n*ctx.j
|
|
||||||
elif a < 0:
|
|
||||||
n = int(ctx.floor(a))
|
|
||||||
w += (n-(n%2))*ctx.pi*ctx.j
|
|
||||||
return w
|
|
||||||
|
|
||||||
'''
|
|
||||||
@defun
|
|
||||||
def psi0(ctx, z):
|
|
||||||
"""Shortcut for psi(0,z) (the digamma function)"""
|
|
||||||
return ctx.psi(0, z)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def psi1(ctx, z):
|
|
||||||
"""Shortcut for psi(1,z) (the trigamma function)"""
|
|
||||||
return ctx.psi(1, z)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def psi2(ctx, z):
|
|
||||||
"""Shortcut for psi(2,z) (the tetragamma function)"""
|
|
||||||
return ctx.psi(2, z)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def psi3(ctx, z):
|
|
||||||
"""Shortcut for psi(3,z) (the pentagamma function)"""
|
|
||||||
return ctx.psi(3, z)
|
|
||||||
'''
|
|
||||||
|
|
@ -1,435 +0,0 @@
|
||||||
class SpecialFunctions(object):
|
|
||||||
"""
|
|
||||||
This class implements special functions using high-level code.
|
|
||||||
|
|
||||||
Elementary and some other functions (e.g. gamma function, basecase
|
|
||||||
hypergeometric series) are assumed to be predefined by the context as
|
|
||||||
"builtins" or "low-level" functions.
|
|
||||||
"""
|
|
||||||
defined_functions = {}
|
|
||||||
|
|
||||||
# The series for the Jacobi theta functions converge for |q| < 1;
|
|
||||||
# in the current implementation they throw a ValueError for
|
|
||||||
# abs(q) > THETA_Q_LIM
|
|
||||||
THETA_Q_LIM = 1 - 10**-7
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
cls = self.__class__
|
|
||||||
for name in cls.defined_functions:
|
|
||||||
f, wrap = cls.defined_functions[name]
|
|
||||||
cls._wrap_specfun(name, f, wrap)
|
|
||||||
|
|
||||||
self.mpq_1 = self._mpq((1,1))
|
|
||||||
self.mpq_0 = self._mpq((0,1))
|
|
||||||
self.mpq_1_2 = self._mpq((1,2))
|
|
||||||
self.mpq_3_2 = self._mpq((3,2))
|
|
||||||
self.mpq_1_4 = self._mpq((1,4))
|
|
||||||
self.mpq_1_16 = self._mpq((1,16))
|
|
||||||
self.mpq_3_16 = self._mpq((3,16))
|
|
||||||
self.mpq_5_2 = self._mpq((5,2))
|
|
||||||
self.mpq_3_4 = self._mpq((3,4))
|
|
||||||
self.mpq_7_4 = self._mpq((7,4))
|
|
||||||
self.mpq_5_4 = self._mpq((5,4))
|
|
||||||
|
|
||||||
self._aliases.update({
|
|
||||||
'phase' : 'arg',
|
|
||||||
'conjugate' : 'conj',
|
|
||||||
'nthroot' : 'root',
|
|
||||||
'polygamma' : 'psi',
|
|
||||||
'hurwitz' : 'zeta',
|
|
||||||
#'digamma' : 'psi0',
|
|
||||||
#'trigamma' : 'psi1',
|
|
||||||
#'tetragamma' : 'psi2',
|
|
||||||
#'pentagamma' : 'psi3',
|
|
||||||
'fibonacci' : 'fib',
|
|
||||||
'factorial' : 'fac',
|
|
||||||
})
|
|
||||||
|
|
||||||
# Default -- do nothing
|
|
||||||
@classmethod
|
|
||||||
def _wrap_specfun(cls, name, f, wrap):
|
|
||||||
setattr(cls, name, f)
|
|
||||||
|
|
||||||
# Optional fast versions of common functions in common cases.
|
|
||||||
# If not overridden, default (generic hypergeometric series)
|
|
||||||
# implementations will be used
|
|
||||||
def _besselj(ctx, n, z): raise NotImplementedError
|
|
||||||
def _erf(ctx, z): raise NotImplementedError
|
|
||||||
def _erfc(ctx, z): raise NotImplementedError
|
|
||||||
def _gamma_upper_int(ctx, z, a): raise NotImplementedError
|
|
||||||
def _expint_int(ctx, n, z): raise NotImplementedError
|
|
||||||
def _zeta(ctx, s): raise NotImplementedError
|
|
||||||
def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
|
|
||||||
def _ei(ctx, z): raise NotImplementedError
|
|
||||||
def _e1(ctx, z): raise NotImplementedError
|
|
||||||
def _ci(ctx, z): raise NotImplementedError
|
|
||||||
def _si(ctx, z): raise NotImplementedError
|
|
||||||
def _altzeta(ctx, s): raise NotImplementedError
|
|
||||||
|
|
||||||
def defun_wrapped(f):
|
|
||||||
SpecialFunctions.defined_functions[f.__name__] = f, True
|
|
||||||
|
|
||||||
def defun(f):
|
|
||||||
SpecialFunctions.defined_functions[f.__name__] = f, False
|
|
||||||
|
|
||||||
def defun_static(f):
|
|
||||||
setattr(SpecialFunctions, f.__name__, f)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def cot(ctx, z): return ctx.one / ctx.tan(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def sec(ctx, z): return ctx.one / ctx.cos(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def csc(ctx, z): return ctx.one / ctx.sin(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def coth(ctx, z): return ctx.one / ctx.tanh(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def sech(ctx, z): return ctx.one / ctx.cosh(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def csch(ctx, z): return ctx.one / ctx.sinh(z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def acot(ctx, z): return ctx.atan(ctx.one / z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def asec(ctx, z): return ctx.acos(ctx.one / z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def acsc(ctx, z): return ctx.asin(ctx.one / z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def acoth(ctx, z): return ctx.atanh(ctx.one / z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def asech(ctx, z): return ctx.acosh(ctx.one / z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def acsch(ctx, z): return ctx.asinh(ctx.one / z)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def sign(ctx, x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if not x or ctx.isnan(x):
|
|
||||||
return x
|
|
||||||
if ctx._is_real_type(x):
|
|
||||||
return ctx.mpf(cmp(x, 0))
|
|
||||||
return x / abs(x)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def agm(ctx, a, b=1):
|
|
||||||
if b == 1:
|
|
||||||
return ctx.agm1(a)
|
|
||||||
a = ctx.convert(a)
|
|
||||||
b = ctx.convert(b)
|
|
||||||
return ctx._agm(a, b)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def sinc(ctx, x):
|
|
||||||
if ctx.isinf(x):
|
|
||||||
return 1/x
|
|
||||||
if not x:
|
|
||||||
return x+1
|
|
||||||
return ctx.sin(x)/x
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def sincpi(ctx, x):
|
|
||||||
if ctx.isinf(x):
|
|
||||||
return 1/x
|
|
||||||
if not x:
|
|
||||||
return x+1
|
|
||||||
return ctx.sinpi(x)/(ctx.pi*x)
|
|
||||||
|
|
||||||
# TODO: tests; improve implementation
|
|
||||||
@defun_wrapped
|
|
||||||
def expm1(ctx, x):
|
|
||||||
if not x:
|
|
||||||
return ctx.zero
|
|
||||||
# exp(x) - 1 ~ x
|
|
||||||
if ctx.mag(x) < -ctx.prec:
|
|
||||||
return x + 0.5*x**2
|
|
||||||
# TODO: accurately eval the smaller of the real/imag parts
|
|
||||||
return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def powm1(ctx, x, y):
|
|
||||||
mag = ctx.mag
|
|
||||||
one = ctx.one
|
|
||||||
w = x**y - one
|
|
||||||
M = mag(w)
|
|
||||||
# Only moderate cancellation
|
|
||||||
if M > -8:
|
|
||||||
return w
|
|
||||||
# Check for the only possible exact cases
|
|
||||||
if not w:
|
|
||||||
if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
|
|
||||||
return w
|
|
||||||
x1 = x - one
|
|
||||||
magy = mag(y)
|
|
||||||
lnx = ctx.ln(x)
|
|
||||||
# Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
|
|
||||||
if magy + mag(lnx) < -ctx.prec:
|
|
||||||
return lnx*y + (lnx*y)**2/2
|
|
||||||
# TODO: accurately eval the smaller of the real/imag part
|
|
||||||
return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def _rootof1(ctx, k, n):
|
|
||||||
k = int(k)
|
|
||||||
n = int(n)
|
|
||||||
k %= n
|
|
||||||
if not k:
|
|
||||||
return ctx.one
|
|
||||||
elif 2*k == n:
|
|
||||||
return -ctx.one
|
|
||||||
elif 4*k == n:
|
|
||||||
return ctx.j
|
|
||||||
elif 4*k == 3*n:
|
|
||||||
return -ctx.j
|
|
||||||
return ctx.expjpi(2*ctx.mpf(k)/n)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def root(ctx, x, n, k=0):
|
|
||||||
n = int(n)
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if k:
|
|
||||||
# Special case: there is an exact real root
|
|
||||||
if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
|
|
||||||
return -ctx.root(-x, n)
|
|
||||||
# Multiply by root of unity
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return +v
|
|
||||||
return ctx._nthroot(x, n)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def unitroots(ctx, n, primitive=False):
|
|
||||||
gcd = ctx._gcd
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
if primitive:
|
|
||||||
v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
|
|
||||||
else:
|
|
||||||
# TODO: this can be done *much* faster
|
|
||||||
v = [ctx._rootof1(k,n) for k in range(n)]
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return [+x for x in v]
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def arg(ctx, x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
re = ctx._re(x)
|
|
||||||
im = ctx._im(x)
|
|
||||||
return ctx.atan2(im, re)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def fabs(ctx, x):
|
|
||||||
return abs(ctx.convert(x))
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def re(ctx, x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
|
|
||||||
return x.real
|
|
||||||
return x
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def im(ctx, x):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
|
|
||||||
return x.imag
|
|
||||||
return ctx.zero
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def conj(ctx, x):
|
|
||||||
return ctx.convert(x).conjugate()
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def polar(ctx, z):
|
|
||||||
return (ctx.fabs(z), ctx.arg(z))
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def rect(ctx, r, phi):
|
|
||||||
return r * ctx.mpc(*ctx.cos_sin(phi))
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def log(ctx, x, b=None):
|
|
||||||
if b is None:
|
|
||||||
return ctx.ln(x)
|
|
||||||
wp = ctx.prec + 20
|
|
||||||
return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def log10(ctx, x):
|
|
||||||
return ctx.log(x, 10)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def modf(ctx, x, y):
|
|
||||||
return ctx.convert(x) % ctx.convert(y)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def degrees(ctx, x):
|
|
||||||
return x / ctx.degree
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def radians(ctx, x):
|
|
||||||
return x * ctx.degree
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def lambertw(ctx, z, k=0):
|
|
||||||
k = int(k)
|
|
||||||
if ctx.isnan(z):
|
|
||||||
return z
|
|
||||||
ctx.prec += 20
|
|
||||||
mag = ctx.mag(z)
|
|
||||||
# Start from fp approximation
|
|
||||||
if ctx is ctx._mp and abs(mag) < 900 and abs(k) < 10000 and \
|
|
||||||
abs(z+0.36787944117144) > 0.01:
|
|
||||||
w = ctx._fp.lambertw(z, k)
|
|
||||||
else:
|
|
||||||
absz = abs(z)
|
|
||||||
# We must be extremely careful near the singularities at -1/e and 0
|
|
||||||
u = ctx.exp(-1)
|
|
||||||
if absz <= u:
|
|
||||||
if not z:
|
|
||||||
# w(0,0) = 0; for all other branches we hit the pole
|
|
||||||
if not k:
|
|
||||||
return z
|
|
||||||
return ctx.ninf
|
|
||||||
if not k:
|
|
||||||
w = z
|
|
||||||
# For small real z < 0, the -1 branch aves roughly like log(-z)
|
|
||||||
elif k == -1 and not ctx.im(z) and ctx.re(z) < 0:
|
|
||||||
w = ctx.ln(-z)
|
|
||||||
# Use a simple asymptotic approximation.
|
|
||||||
else:
|
|
||||||
w = ctx.ln(z)
|
|
||||||
# The branches are roughly logarithmic. This approximation
|
|
||||||
# gets better for large |k|; need to check that this always
|
|
||||||
# works for k ~= -1, 0, 1.
|
|
||||||
if k: w += k * 2*ctx.pi*ctx.j
|
|
||||||
elif k == 0 and ctx.im(z) and absz <= 0.7:
|
|
||||||
# Both the W(z) ~= z and W(z) ~= ln(z) approximations break
|
|
||||||
# down around z ~= -0.5 (converging to the wrong branch), so patch
|
|
||||||
# with a constant approximation (adjusted for sign)
|
|
||||||
if abs(z+0.5) < 0.1:
|
|
||||||
if ctx.im(z) > 0:
|
|
||||||
w = ctx.mpc(0.7+0.7j)
|
|
||||||
else:
|
|
||||||
w = ctx.mpc(0.7-0.7j)
|
|
||||||
else:
|
|
||||||
w = z
|
|
||||||
else:
|
|
||||||
if z == ctx.inf:
|
|
||||||
if k == 0:
|
|
||||||
return z
|
|
||||||
else:
|
|
||||||
return z + 2*k*ctx.pi*ctx.j
|
|
||||||
if z == ctx.ninf:
|
|
||||||
return (-z) + (2*k+1)*ctx.pi*ctx.j
|
|
||||||
# Simple asymptotic approximation as above
|
|
||||||
w = ctx.ln(z)
|
|
||||||
if k:
|
|
||||||
w += k * 2*ctx.pi*ctx.j
|
|
||||||
# Use Halley iteration to solve w*exp(w) = z
|
|
||||||
two = ctx.mpf(2)
|
|
||||||
weps = ctx.ldexp(ctx.eps, 15)
|
|
||||||
for i in xrange(100):
|
|
||||||
ew = ctx.exp(w)
|
|
||||||
wew = w*ew
|
|
||||||
wewz = wew-z
|
|
||||||
wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
|
|
||||||
if abs(wn-w) < weps*abs(wn):
|
|
||||||
return wn
|
|
||||||
else:
|
|
||||||
w = wn
|
|
||||||
ctx.warn("Lambert W iteration failed to converge for %s" % z)
|
|
||||||
return wn
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def bell(ctx, n, x=1):
|
|
||||||
x = ctx.convert(x)
|
|
||||||
if not n:
|
|
||||||
if ctx.isnan(x):
|
|
||||||
return x
|
|
||||||
return type(x)(1)
|
|
||||||
if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
|
|
||||||
return x**n
|
|
||||||
if n == 1: return x
|
|
||||||
if n == 2: return x*(x+1)
|
|
||||||
if x == 0: return ctx.sincpi(n)
|
|
||||||
return _polyexp(ctx, n, x, True) / ctx.exp(x)
|
|
||||||
|
|
||||||
def _polyexp(ctx, n, x, extra=False):
|
|
||||||
def _terms():
|
|
||||||
if extra:
|
|
||||||
yield ctx.sincpi(n)
|
|
||||||
t = x
|
|
||||||
k = 1
|
|
||||||
while 1:
|
|
||||||
yield k**n * t
|
|
||||||
k += 1
|
|
||||||
t = t*x/k
|
|
||||||
return ctx.sum_accurately(_terms, check_step=4)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def polyexp(ctx, s, z):
|
|
||||||
if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
|
|
||||||
return z**s
|
|
||||||
if z == 0: return z*s
|
|
||||||
if s == 0: return ctx.expm1(z)
|
|
||||||
if s == 1: return ctx.exp(z)*z
|
|
||||||
if s == 2: return ctx.exp(z)*z*(z+1)
|
|
||||||
return _polyexp(ctx, s, z)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def cyclotomic(ctx, n, z):
|
|
||||||
n = int(n)
|
|
||||||
assert n >= 0
|
|
||||||
p = ctx.one
|
|
||||||
if n == 0:
|
|
||||||
return p
|
|
||||||
if n == 1:
|
|
||||||
return z - p
|
|
||||||
if n == 2:
|
|
||||||
return z + p
|
|
||||||
# Use divisor product representation. Unfortunately, this sometimes
|
|
||||||
# includes singularities for roots of unity, which we have to cancel out.
|
|
||||||
# Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
|
|
||||||
a_prod = 1
|
|
||||||
b_prod = 1
|
|
||||||
num_zeros = 0
|
|
||||||
num_poles = 0
|
|
||||||
for d in range(1,n+1):
|
|
||||||
if not n % d:
|
|
||||||
w = ctx.moebius(n//d)
|
|
||||||
# Use powm1 because it is important that we get 0 only
|
|
||||||
# if it really is exactly 0
|
|
||||||
b = -ctx.powm1(z, d)
|
|
||||||
if b:
|
|
||||||
p *= b**w
|
|
||||||
else:
|
|
||||||
if w == 1:
|
|
||||||
a_prod *= d
|
|
||||||
num_zeros += 1
|
|
||||||
elif w == -1:
|
|
||||||
b_prod *= d
|
|
||||||
num_poles += 1
|
|
||||||
#print n, num_zeros, num_poles
|
|
||||||
if num_zeros:
|
|
||||||
if num_zeros > num_poles:
|
|
||||||
p *= 0
|
|
||||||
else:
|
|
||||||
p *= a_prod
|
|
||||||
p /= b_prod
|
|
||||||
return p
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,693 +0,0 @@
|
||||||
from functions import defun, defun_wrapped, defun_static
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def stieltjes(ctx, n, a=1):
|
|
||||||
n = ctx.convert(n)
|
|
||||||
a = ctx.convert(a)
|
|
||||||
if n < 0:
|
|
||||||
return ctx.bad_domain("Stieltjes constants defined for n >= 0")
|
|
||||||
if hasattr(ctx, "stieltjes_cache"):
|
|
||||||
stieltjes_cache = ctx.stieltjes_cache
|
|
||||||
else:
|
|
||||||
stieltjes_cache = ctx.stieltjes_cache = {}
|
|
||||||
if a == 1:
|
|
||||||
if n == 0:
|
|
||||||
return +ctx.euler
|
|
||||||
if n in stieltjes_cache:
|
|
||||||
prec, s = stieltjes_cache[n]
|
|
||||||
if prec >= ctx.prec:
|
|
||||||
return +s
|
|
||||||
mag = 1
|
|
||||||
def f(x):
|
|
||||||
xa = x/a
|
|
||||||
v = (xa-ctx.j)*ctx.ln(a-ctx.j*x)**n/(1+xa**2)/(ctx.exp(2*ctx.pi*x)-1)
|
|
||||||
return ctx._re(v) / mag
|
|
||||||
orig = ctx.prec
|
|
||||||
try:
|
|
||||||
# Normalize integrand by approx. magnitude to
|
|
||||||
# speed up quadrature (which uses absolute error)
|
|
||||||
if n > 50:
|
|
||||||
ctx.prec = 20
|
|
||||||
mag = ctx.quad(f, [0,ctx.inf], maxdegree=3)
|
|
||||||
ctx.prec = orig + 10 + int(n**0.5)
|
|
||||||
s = ctx.quad(f, [0,ctx.inf], maxdegree=20)
|
|
||||||
v = ctx.ln(a)**n/(2*a) - ctx.ln(a)**(n+1)/(n+1) + 2*s/a*mag
|
|
||||||
finally:
|
|
||||||
ctx.prec = orig
|
|
||||||
if a == 1 and ctx.isint(n):
|
|
||||||
stieltjes_cache[n] = (ctx.prec, v)
|
|
||||||
return +v
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def siegeltheta(ctx, t):
|
|
||||||
if ctx._im(t):
|
|
||||||
# XXX: cancellation occurs
|
|
||||||
a = ctx.loggamma(0.25+0.5j*t)
|
|
||||||
b = ctx.loggamma(0.25-0.5j*t)
|
|
||||||
return -ctx.ln(ctx.pi)/2*t - 0.5j*(a-b)
|
|
||||||
else:
|
|
||||||
if ctx.isinf(t):
|
|
||||||
return t
|
|
||||||
return ctx._im(ctx.loggamma(0.25+0.5j*t)) - ctx.ln(ctx.pi)/2*t
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def grampoint(ctx, n):
|
|
||||||
# asymptotic expansion, from
|
|
||||||
# http://mathworld.wolfram.com/GramPoint.html
|
|
||||||
g = 2*ctx.pi*ctx.exp(1+ctx.lambertw((8*n+1)/(8*ctx.e)))
|
|
||||||
return ctx.findroot(lambda t: ctx.siegeltheta(t)-ctx.pi*n, g)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def siegelz(ctx, t):
|
|
||||||
v = ctx.expj(ctx.siegeltheta(t))*ctx.zeta(0.5+ctx.j*t)
|
|
||||||
if ctx._is_real_type(t):
|
|
||||||
return ctx._re(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
_zeta_zeros = [
|
|
||||||
14.134725142,21.022039639,25.010857580,30.424876126,32.935061588,
|
|
||||||
37.586178159,40.918719012,43.327073281,48.005150881,49.773832478,
|
|
||||||
52.970321478,56.446247697,59.347044003,60.831778525,65.112544048,
|
|
||||||
67.079810529,69.546401711,72.067157674,75.704690699,77.144840069,
|
|
||||||
79.337375020,82.910380854,84.735492981,87.425274613,88.809111208,
|
|
||||||
92.491899271,94.651344041,95.870634228,98.831194218,101.317851006,
|
|
||||||
103.725538040,105.446623052,107.168611184,111.029535543,111.874659177,
|
|
||||||
114.320220915,116.226680321,118.790782866,121.370125002,122.946829294,
|
|
||||||
124.256818554,127.516683880,129.578704200,131.087688531,133.497737203,
|
|
||||||
134.756509753,138.116042055,139.736208952,141.123707404,143.111845808,
|
|
||||||
146.000982487,147.422765343,150.053520421,150.925257612,153.024693811,
|
|
||||||
156.112909294,157.597591818,158.849988171,161.188964138,163.030709687,
|
|
||||||
165.537069188,167.184439978,169.094515416,169.911976479,173.411536520,
|
|
||||||
174.754191523,176.441434298,178.377407776,179.916484020,182.207078484,
|
|
||||||
184.874467848,185.598783678,187.228922584,189.416158656,192.026656361,
|
|
||||||
193.079726604,195.265396680,196.876481841,198.015309676,201.264751944,
|
|
||||||
202.493594514,204.189671803,205.394697202,207.906258888,209.576509717,
|
|
||||||
211.690862595,213.347919360,214.547044783,216.169538508,219.067596349,
|
|
||||||
220.714918839,221.430705555,224.007000255,224.983324670,227.421444280,
|
|
||||||
229.337413306,231.250188700,231.987235253,233.693404179,236.524229666,
|
|
||||||
]
|
|
||||||
|
|
||||||
def _load_zeta_zeros(url):
|
|
||||||
import urllib
|
|
||||||
d = urllib.urlopen(url)
|
|
||||||
L = [float(x) for x in d.readlines()]
|
|
||||||
# Sanity check
|
|
||||||
assert round(L[0]) == 14
|
|
||||||
_zeta_zeros[:] = L
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def zetazero(ctx, n, url='http://www.dtc.umn.edu/~odlyzko/zeta_tables/zeros1'):
|
|
||||||
n = int(n)
|
|
||||||
if n < 0:
|
|
||||||
return ctx.zetazero(-n).conjugate()
|
|
||||||
if n == 0:
|
|
||||||
raise ValueError("n must be nonzero")
|
|
||||||
if n > len(_zeta_zeros) and n <= 100000:
|
|
||||||
_load_zeta_zeros(url)
|
|
||||||
if n > len(_zeta_zeros):
|
|
||||||
raise NotImplementedError("n too large for zetazeros")
|
|
||||||
return ctx.mpc(0.5, ctx.findroot(ctx.siegelz, _zeta_zeros[n-1]))
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def riemannr(ctx, x):
|
|
||||||
if x == 0:
|
|
||||||
return ctx.zero
|
|
||||||
# Check if a simple asymptotic estimate is accurate enough
|
|
||||||
if abs(x) > 1000:
|
|
||||||
a = ctx.li(x)
|
|
||||||
b = 0.5*ctx.li(ctx.sqrt(x))
|
|
||||||
if abs(b) < abs(a)*ctx.eps:
|
|
||||||
return a
|
|
||||||
if abs(x) < 0.01:
|
|
||||||
# XXX
|
|
||||||
ctx.prec += int(-ctx.log(abs(x),2))
|
|
||||||
# Sum Gram's series
|
|
||||||
s = t = ctx.one
|
|
||||||
u = ctx.ln(x)
|
|
||||||
k = 1
|
|
||||||
while abs(t) > abs(s)*ctx.eps:
|
|
||||||
t = t * u / k
|
|
||||||
s += t / (k * ctx._zeta_int(k+1))
|
|
||||||
k += 1
|
|
||||||
return s
|
|
||||||
|
|
||||||
@defun_static
|
|
||||||
def primepi(ctx, x):
|
|
||||||
x = int(x)
|
|
||||||
if x < 2:
|
|
||||||
return 0
|
|
||||||
return len(ctx.list_primes(x))
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def primepi2(ctx, x):
|
|
||||||
x = int(x)
|
|
||||||
if x < 2:
|
|
||||||
return ctx.mpi(0,0)
|
|
||||||
if x < 2657:
|
|
||||||
return ctx.mpi(ctx.primepi(x))
|
|
||||||
mid = ctx.li(x)
|
|
||||||
# Schoenfeld's estimate for x >= 2657, assuming RH
|
|
||||||
err = ctx.sqrt(x,rounding='u')*ctx.ln(x,rounding='u')/8/ctx.pi(rounding='d')
|
|
||||||
a = ctx.floor((ctx.mpi(mid)-err).a, rounding='d')
|
|
||||||
b = ctx.ceil((ctx.mpi(mid)+err).b, rounding='u')
|
|
||||||
return ctx.mpi(a, b)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def primezeta(ctx, s):
|
|
||||||
if ctx.isnan(s):
|
|
||||||
return s
|
|
||||||
if ctx.re(s) <= 0:
|
|
||||||
raise ValueError("prime zeta function defined only for re(s) > 0")
|
|
||||||
if s == 1:
|
|
||||||
return ctx.inf
|
|
||||||
if s == 0.5:
|
|
||||||
return ctx.mpc(ctx.ninf, ctx.pi)
|
|
||||||
r = ctx.re(s)
|
|
||||||
if r > ctx.prec:
|
|
||||||
return 0.5**s
|
|
||||||
else:
|
|
||||||
wp = ctx.prec + int(r)
|
|
||||||
def terms():
|
|
||||||
orig = ctx.prec
|
|
||||||
# zeta ~ 1+eps; need to set precision
|
|
||||||
# to get logarithm accurately
|
|
||||||
k = 0
|
|
||||||
while 1:
|
|
||||||
k += 1
|
|
||||||
u = ctx.moebius(k)
|
|
||||||
if not u:
|
|
||||||
continue
|
|
||||||
ctx.prec = wp
|
|
||||||
t = u*ctx.ln(ctx.zeta(k*s))/k
|
|
||||||
if not t:
|
|
||||||
return
|
|
||||||
#print ctx.prec, ctx.nstr(t)
|
|
||||||
ctx.prec = orig
|
|
||||||
yield t
|
|
||||||
return ctx.sum_accurately(terms)
|
|
||||||
|
|
||||||
# TODO: for bernpoly and eulerpoly, ensure that all exact zeros are covered
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def bernpoly(ctx, n, z):
|
|
||||||
# Slow implementation:
|
|
||||||
#return sum(ctx.binomial(n,k)*ctx.bernoulli(k)*z**(n-k) for k in xrange(0,n+1))
|
|
||||||
n = int(n)
|
|
||||||
if n < 0:
|
|
||||||
raise ValueError("Bernoulli polynomials only defined for n >= 0")
|
|
||||||
if z == 0 or (z == 1 and n > 1):
|
|
||||||
return ctx.bernoulli(n)
|
|
||||||
if z == 0.5:
|
|
||||||
return (ctx.ldexp(1,1-n)-1)*ctx.bernoulli(n)
|
|
||||||
if n <= 3:
|
|
||||||
if n == 0: return z ** 0
|
|
||||||
if n == 1: return z - 0.5
|
|
||||||
if n == 2: return (6*z*(z-1)+1)/6
|
|
||||||
if n == 3: return z*(z*(z-1.5)+0.5)
|
|
||||||
if abs(z) == ctx.inf:
|
|
||||||
return z ** n
|
|
||||||
if z != z:
|
|
||||||
return z
|
|
||||||
if abs(z) > 2:
|
|
||||||
def terms():
|
|
||||||
t = ctx.one
|
|
||||||
yield t
|
|
||||||
r = ctx.one/z
|
|
||||||
k = 1
|
|
||||||
while k <= n:
|
|
||||||
t = t*(n+1-k)/k*r
|
|
||||||
if not (k > 2 and k & 1):
|
|
||||||
yield t*ctx.bernoulli(k)
|
|
||||||
k += 1
|
|
||||||
return ctx.sum_accurately(terms) * z**n
|
|
||||||
else:
|
|
||||||
def terms():
|
|
||||||
yield ctx.bernoulli(n)
|
|
||||||
t = ctx.one
|
|
||||||
k = 1
|
|
||||||
while k <= n:
|
|
||||||
t = t*(n+1-k)/k * z
|
|
||||||
m = n-k
|
|
||||||
if not (m > 2 and m & 1):
|
|
||||||
yield t*ctx.bernoulli(m)
|
|
||||||
k += 1
|
|
||||||
return ctx.sum_accurately(terms)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def eulerpoly(ctx, n, z):
|
|
||||||
n = int(n)
|
|
||||||
if n < 0:
|
|
||||||
raise ValueError("Euler polynomials only defined for n >= 0")
|
|
||||||
if n <= 2:
|
|
||||||
if n == 0: return z ** 0
|
|
||||||
if n == 1: return z - 0.5
|
|
||||||
if n == 2: return z*(z-1)
|
|
||||||
if abs(z) == ctx.inf:
|
|
||||||
return z**n
|
|
||||||
if z != z:
|
|
||||||
return z
|
|
||||||
m = n+1
|
|
||||||
if z == 0:
|
|
||||||
return -2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
|
|
||||||
if z == 1:
|
|
||||||
return 2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
|
|
||||||
if z == 0.5:
|
|
||||||
if n % 2:
|
|
||||||
return ctx.zero
|
|
||||||
# Use exact code for Euler numbers
|
|
||||||
if n < 100 or n*ctx.mag(0.46839865*n) < ctx.prec*0.25:
|
|
||||||
return ctx.ldexp(ctx._eulernum(n), -n)
|
|
||||||
# http://functions.wolfram.com/Polynomials/EulerE2/06/01/02/01/0002/
|
|
||||||
def terms():
|
|
||||||
t = ctx.one
|
|
||||||
k = 0
|
|
||||||
w = ctx.ldexp(1,n+2)
|
|
||||||
while 1:
|
|
||||||
v = n-k+1
|
|
||||||
if not (v > 2 and v & 1):
|
|
||||||
yield (2-w)*ctx.bernoulli(v)*t
|
|
||||||
k += 1
|
|
||||||
if k > n:
|
|
||||||
break
|
|
||||||
t = t*z*(n-k+2)/k
|
|
||||||
w *= 0.5
|
|
||||||
return ctx.sum_accurately(terms) / m
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def eulernum(ctx, n, exact=False):
|
|
||||||
n = int(n)
|
|
||||||
if exact:
|
|
||||||
return int(ctx._eulernum(n))
|
|
||||||
if n < 100:
|
|
||||||
return ctx.mpf(ctx._eulernum(n))
|
|
||||||
if n % 2:
|
|
||||||
return ctx.zero
|
|
||||||
return ctx.ldexp(ctx.eulerpoly(n,0.5), n)
|
|
||||||
|
|
||||||
# TODO: this should be implemented low-level
|
|
||||||
def polylog_series(ctx, s, z):
|
|
||||||
tol = +ctx.eps
|
|
||||||
l = ctx.zero
|
|
||||||
k = 1
|
|
||||||
zk = z
|
|
||||||
while 1:
|
|
||||||
term = zk / k**s
|
|
||||||
l += term
|
|
||||||
if abs(term) < tol:
|
|
||||||
break
|
|
||||||
zk *= z
|
|
||||||
k += 1
|
|
||||||
return l
|
|
||||||
|
|
||||||
def polylog_continuation(ctx, n, z):
|
|
||||||
if n < 0:
|
|
||||||
return z*0
|
|
||||||
twopij = 2j * ctx.pi
|
|
||||||
a = -twopij**n/ctx.fac(n) * ctx.bernpoly(n, ctx.ln(z)/twopij)
|
|
||||||
if ctx._is_real_type(z) and z < 0:
|
|
||||||
a = ctx._re(a)
|
|
||||||
if ctx._im(z) < 0 or (ctx._im(z) == 0 and ctx._re(z) >= 1):
|
|
||||||
a -= twopij*ctx.ln(z)**(n-1)/ctx.fac(n-1)
|
|
||||||
return a
|
|
||||||
|
|
||||||
def polylog_unitcircle(ctx, n, z):
|
|
||||||
tol = +ctx.eps
|
|
||||||
if n > 1:
|
|
||||||
l = ctx.zero
|
|
||||||
logz = ctx.ln(z)
|
|
||||||
logmz = ctx.one
|
|
||||||
m = 0
|
|
||||||
while 1:
|
|
||||||
if (n-m) != 1:
|
|
||||||
term = ctx.zeta(n-m) * logmz / ctx.fac(m)
|
|
||||||
if term and abs(term) < tol:
|
|
||||||
break
|
|
||||||
l += term
|
|
||||||
logmz *= logz
|
|
||||||
m += 1
|
|
||||||
l += ctx.ln(z)**(n-1)/ctx.fac(n-1)*(ctx.harmonic(n-1)-ctx.ln(-ctx.ln(z)))
|
|
||||||
elif n < 1: # else
|
|
||||||
l = ctx.fac(-n)*(-ctx.ln(z))**(n-1)
|
|
||||||
logz = ctx.ln(z)
|
|
||||||
logkz = ctx.one
|
|
||||||
k = 0
|
|
||||||
while 1:
|
|
||||||
b = ctx.bernoulli(k-n+1)
|
|
||||||
if b:
|
|
||||||
term = b*logkz/(ctx.fac(k)*(k-n+1))
|
|
||||||
if abs(term) < tol:
|
|
||||||
break
|
|
||||||
l -= term
|
|
||||||
logkz *= logz
|
|
||||||
k += 1
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
if ctx._is_real_type(z) and z < 0:
|
|
||||||
l = ctx._re(l)
|
|
||||||
return l
|
|
||||||
|
|
||||||
def polylog_general(ctx, s, z):
|
|
||||||
v = ctx.zero
|
|
||||||
u = ctx.ln(z)
|
|
||||||
if not abs(u) < 5: # theoretically |u| < 2*pi
|
|
||||||
raise NotImplementedError("polylog for arbitrary s and z")
|
|
||||||
t = 1
|
|
||||||
k = 0
|
|
||||||
while 1:
|
|
||||||
term = ctx.zeta(s-k) * t
|
|
||||||
if abs(term) < ctx.eps:
|
|
||||||
break
|
|
||||||
v += term
|
|
||||||
k += 1
|
|
||||||
t *= u
|
|
||||||
t /= k
|
|
||||||
return ctx.gamma(1-s)*(-u)**(s-1) + v
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def polylog(ctx, s, z):
|
|
||||||
s = ctx.convert(s)
|
|
||||||
z = ctx.convert(z)
|
|
||||||
if z == 1:
|
|
||||||
return ctx.zeta(s)
|
|
||||||
if z == -1:
|
|
||||||
return -ctx.altzeta(s)
|
|
||||||
if s == 0:
|
|
||||||
return z/(1-z)
|
|
||||||
if s == 1:
|
|
||||||
return -ctx.ln(1-z)
|
|
||||||
if s == -1:
|
|
||||||
return z/(1-z)**2
|
|
||||||
if abs(z) <= 0.75 or (not ctx.isint(s) and abs(z) < 0.9):
|
|
||||||
return polylog_series(ctx, s, z)
|
|
||||||
if abs(z) >= 1.4 and ctx.isint(s):
|
|
||||||
return (-1)**(s+1)*polylog_series(ctx, s, 1/z) + polylog_continuation(ctx, s, z)
|
|
||||||
if ctx.isint(s):
|
|
||||||
return polylog_unitcircle(ctx, int(s), z)
|
|
||||||
return polylog_general(ctx, s, z)
|
|
||||||
|
|
||||||
#raise NotImplementedError("polylog for arbitrary s and z")
|
|
||||||
# This could perhaps be used in some cases
|
|
||||||
#from quadrature import quad
|
|
||||||
#return quad(lambda t: t**(s-1)/(exp(t)/z-1),[0,inf])/gamma(s)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def clsin(ctx, s, z, pi=False):
|
|
||||||
if ctx.isint(s) and s < 0 and int(s) % 2 == 1:
|
|
||||||
return z*0
|
|
||||||
if pi:
|
|
||||||
a = ctx.expjpi(z)
|
|
||||||
else:
|
|
||||||
a = ctx.expj(z)
|
|
||||||
if ctx._is_real_type(z) and ctx._is_real_type(s):
|
|
||||||
return ctx.im(ctx.polylog(s,a))
|
|
||||||
b = 1/a
|
|
||||||
return (-0.5j)*(ctx.polylog(s,a) - ctx.polylog(s,b))
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def clcos(ctx, s, z, pi=False):
|
|
||||||
if ctx.isint(s) and s < 0 and int(s) % 2 == 0:
|
|
||||||
return z*0
|
|
||||||
if pi:
|
|
||||||
a = ctx.expjpi(z)
|
|
||||||
else:
|
|
||||||
a = ctx.expj(z)
|
|
||||||
if ctx._is_real_type(z) and ctx._is_real_type(s):
|
|
||||||
return ctx.re(ctx.polylog(s,a))
|
|
||||||
b = 1/a
|
|
||||||
return 0.5*(ctx.polylog(s,a) + ctx.polylog(s,b))
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def altzeta(ctx, s, **kwargs):
|
|
||||||
try:
|
|
||||||
return ctx._altzeta(s, **kwargs)
|
|
||||||
except NotImplementedError:
|
|
||||||
return ctx._altzeta_generic(s)
|
|
||||||
|
|
||||||
@defun_wrapped
|
|
||||||
def _altzeta_generic(ctx, s):
|
|
||||||
if s == 1:
|
|
||||||
return ctx.ln2 + 0*s
|
|
||||||
return -ctx.powm1(2, 1-s) * ctx.zeta(s)
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def zeta(ctx, s, a=1, derivative=0, method=None, **kwargs):
|
|
||||||
d = int(derivative)
|
|
||||||
if a == 1 and not (d or method):
|
|
||||||
try:
|
|
||||||
return ctx._zeta(s, **kwargs)
|
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
s = ctx.convert(s)
|
|
||||||
prec = ctx.prec
|
|
||||||
method = kwargs.get('method')
|
|
||||||
verbose = kwargs.get('verbose')
|
|
||||||
if a == 1 and method != 'euler-maclaurin':
|
|
||||||
im = abs(ctx._im(s))
|
|
||||||
re = abs(ctx._re(s))
|
|
||||||
#if (im < prec or method == 'borwein') and not derivative:
|
|
||||||
# try:
|
|
||||||
# if verbose:
|
|
||||||
# print "zeta: Attempting to use the Borwein algorithm"
|
|
||||||
# return ctx._zeta(s, **kwargs)
|
|
||||||
# except NotImplementedError:
|
|
||||||
# if verbose:
|
|
||||||
# print "zeta: Could not use the Borwein algorithm"
|
|
||||||
# pass
|
|
||||||
if abs(im) > 60*prec and 10*re < prec and derivative <= 4 or \
|
|
||||||
method == 'riemann-siegel':
|
|
||||||
try: # py2.4 compatible try block
|
|
||||||
try:
|
|
||||||
if verbose:
|
|
||||||
print "zeta: Attempting to use the Riemann-Siegel algorithm"
|
|
||||||
return ctx.rs_zeta(s, derivative, **kwargs)
|
|
||||||
except NotImplementedError:
|
|
||||||
if verbose:
|
|
||||||
print "zeta: Could not use the Riemann-Siegel algorithm"
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
if s == 1:
|
|
||||||
return ctx.inf
|
|
||||||
abss = abs(s)
|
|
||||||
if abss == ctx.inf:
|
|
||||||
if ctx.re(s) == ctx.inf:
|
|
||||||
if d == 0:
|
|
||||||
return ctx.one
|
|
||||||
return ctx.zero
|
|
||||||
return s*0
|
|
||||||
elif ctx.isnan(abss):
|
|
||||||
return 1/s
|
|
||||||
if ctx.re(s) > 2*ctx.prec and a == 1 and not derivative:
|
|
||||||
return ctx.one + ctx.power(2, -s)
|
|
||||||
if verbose:
|
|
||||||
print "zeta: Using the Euler-Maclaurin algorithm"
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
v = ctx._hurwitz(s, a, d)
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return +v
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def _hurwitz(ctx, s, a=1, d=0):
|
|
||||||
# We strongly want to special-case rational a
|
|
||||||
a, atype = ctx._convert_param(a)
|
|
||||||
prec = ctx.prec
|
|
||||||
# TODO: implement reflection for derivatives
|
|
||||||
res = ctx.re(s)
|
|
||||||
negs = -s
|
|
||||||
try:
|
|
||||||
if res < 0 and not d:
|
|
||||||
# Integer reflection formula
|
|
||||||
if ctx.isnpint(s):
|
|
||||||
n = int(res)
|
|
||||||
if n <= 0:
|
|
||||||
return ctx.bernpoly(1-n, a) / (n-1)
|
|
||||||
t = 1-s
|
|
||||||
# We now require a to be standardized
|
|
||||||
v = 0
|
|
||||||
shift = 0
|
|
||||||
b = a
|
|
||||||
while ctx.re(b) > 1:
|
|
||||||
b -= 1
|
|
||||||
v -= b**negs
|
|
||||||
shift -= 1
|
|
||||||
while ctx.re(b) <= 0:
|
|
||||||
v += b**negs
|
|
||||||
b += 1
|
|
||||||
shift += 1
|
|
||||||
# Rational reflection formula
|
|
||||||
if atype == 'Q' or atype == 'Z':
|
|
||||||
try:
|
|
||||||
p, q = a
|
|
||||||
except:
|
|
||||||
assert a == int(a)
|
|
||||||
p = int(a)
|
|
||||||
q = 1
|
|
||||||
p += shift*q
|
|
||||||
assert 1 <= p <= q
|
|
||||||
g = ctx.fsum(ctx.cospi(t/2-2*k*b)*ctx._hurwitz(t,(k,q)) \
|
|
||||||
for k in range(1,q+1))
|
|
||||||
g *= 2*ctx.gamma(t)/(2*ctx.pi*q)**t
|
|
||||||
v += g
|
|
||||||
return v
|
|
||||||
# General reflection formula
|
|
||||||
else:
|
|
||||||
C1 = ctx.cospi(t/2)
|
|
||||||
C2 = ctx.sinpi(t/2)
|
|
||||||
# Clausen functions; could maybe use polylog directly
|
|
||||||
if C1: C1 *= ctx.clcos(t, 2*a, pi=True)
|
|
||||||
if C2: C2 *= ctx.clsin(t, 2*a, pi=True)
|
|
||||||
v += 2*ctx.gamma(t)/(2*ctx.pi)**t*(C1+C2)
|
|
||||||
return v
|
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
a = ctx.convert(a)
|
|
||||||
tol = -prec
|
|
||||||
# Estimate number of terms for Euler-Maclaurin summation; could be improved
|
|
||||||
M1 = 0
|
|
||||||
M2 = prec // 3
|
|
||||||
N = M2
|
|
||||||
lsum = 0
|
|
||||||
# This speeds up the recurrence for derivatives
|
|
||||||
if ctx.isint(s):
|
|
||||||
s = int(ctx._re(s))
|
|
||||||
s1 = s-1
|
|
||||||
while 1:
|
|
||||||
# Truncated L-series
|
|
||||||
l = ctx._zetasum(s, M1+a, M2-M1-1, [d])[0][0]
|
|
||||||
#if d:
|
|
||||||
# l = ctx.fsum((-ctx.ln(n+a))**d * (n+a)**negs for n in range(M1,M2))
|
|
||||||
#else:
|
|
||||||
# l = ctx.fsum((n+a)**negs for n in range(M1,M2))
|
|
||||||
lsum += l
|
|
||||||
M2a = M2+a
|
|
||||||
logM2a = ctx.ln(M2a)
|
|
||||||
logM2ad = logM2a**d
|
|
||||||
logs = [logM2ad]
|
|
||||||
logr = 1/logM2a
|
|
||||||
rM2a = 1/M2a
|
|
||||||
M2as = rM2a**s
|
|
||||||
if d:
|
|
||||||
tailsum = ctx.gammainc(d+1, s1*logM2a) / s1**(d+1)
|
|
||||||
else:
|
|
||||||
tailsum = 1/((s1)*(M2a)**s1)
|
|
||||||
tailsum += 0.5 * logM2ad * M2as
|
|
||||||
U = [1]
|
|
||||||
r = M2as
|
|
||||||
fact = 2
|
|
||||||
for j in range(1, N+1):
|
|
||||||
# TODO: the following could perhaps be tidied a bit
|
|
||||||
j2 = 2*j
|
|
||||||
if j == 1:
|
|
||||||
upds = [1]
|
|
||||||
else:
|
|
||||||
upds = [j2-2, j2-1]
|
|
||||||
for m in upds:
|
|
||||||
D = min(m,d+1)
|
|
||||||
if m <= d:
|
|
||||||
logs.append(logs[-1] * logr)
|
|
||||||
Un = [0]*(D+1)
|
|
||||||
for i in xrange(D): Un[i] = (1-m-s)*U[i]
|
|
||||||
for i in xrange(1,D+1): Un[i] += (d-(i-1))*U[i-1]
|
|
||||||
U = Un
|
|
||||||
r *= rM2a
|
|
||||||
t = ctx.fdot(U, logs) * r * ctx.bernoulli(j2)/(-fact)
|
|
||||||
tailsum += t
|
|
||||||
if ctx.mag(t) < tol:
|
|
||||||
return lsum + (-1)**d * tailsum
|
|
||||||
fact *= (j2+1)*(j2+2)
|
|
||||||
M1, M2 = M2, M2*2
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def _zetasum(ctx, s, a, n, derivatives=[0], reflect=False):
|
|
||||||
"""
|
|
||||||
Returns [xd0,xd1,...,xdr], [yd0,yd1,...ydr] where
|
|
||||||
|
|
||||||
xdk = D^k ( 1/a^s + 1/(a+1)^s + ... + 1/(a+n)^s )
|
|
||||||
ydk = D^k conj( 1/a^(1-s) + 1/(a+1)^(1-s) + ... + 1/(a+n)^(1-s) )
|
|
||||||
|
|
||||||
D^k = kth derivative with respect to s, k ranges over the given list of
|
|
||||||
derivatives (which should consist of either a single element
|
|
||||||
or a range 0,1,...r). If reflect=False, the ydks are not computed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return ctx._zetasum_fast(s, a, n, derivatives, reflect)
|
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
negs = ctx.fneg(s, exact=True)
|
|
||||||
have_derivatives = derivatives != [0]
|
|
||||||
have_one_derivative = len(derivatives) == 1
|
|
||||||
if not reflect:
|
|
||||||
if not have_derivatives:
|
|
||||||
return [ctx.fsum((a+k)**negs for k in xrange(n+1))], []
|
|
||||||
if have_one_derivative:
|
|
||||||
d = derivatives[0]
|
|
||||||
x = ctx.fsum(ctx.ln(a+k)**d * (a+k)**negs for k in xrange(n+1))
|
|
||||||
return [(-1)**d * x], []
|
|
||||||
maxd = max(derivatives)
|
|
||||||
if not have_one_derivative:
|
|
||||||
derivatives = range(maxd+1)
|
|
||||||
xs = [ctx.zero for d in derivatives]
|
|
||||||
if reflect:
|
|
||||||
ys = [ctx.zero for d in derivatives]
|
|
||||||
else:
|
|
||||||
ys = []
|
|
||||||
for k in xrange(n+1):
|
|
||||||
w = a + k
|
|
||||||
xterm = w ** negs
|
|
||||||
if reflect:
|
|
||||||
yterm = ctx.conj(ctx.one / (w * xterm))
|
|
||||||
if have_derivatives:
|
|
||||||
logw = -ctx.ln(w)
|
|
||||||
if have_one_derivative:
|
|
||||||
logw = logw ** maxd
|
|
||||||
xs[0] += xterm * logw
|
|
||||||
if reflect:
|
|
||||||
ys[0] += yterm * logw
|
|
||||||
else:
|
|
||||||
t = ctx.one
|
|
||||||
for d in derivatives:
|
|
||||||
xs[d] += xterm * t
|
|
||||||
if reflect:
|
|
||||||
ys[d] += yterm * t
|
|
||||||
t *= logw
|
|
||||||
else:
|
|
||||||
xs[0] += xterm
|
|
||||||
if reflect:
|
|
||||||
ys[0] += yterm
|
|
||||||
return xs, ys
|
|
||||||
|
|
||||||
@defun
|
|
||||||
def dirichlet(ctx, s, chi=[1], derivative=0):
|
|
||||||
s = ctx.convert(s)
|
|
||||||
q = len(chi)
|
|
||||||
d = int(derivative)
|
|
||||||
if d > 2:
|
|
||||||
raise NotImplementedError("arbitrary order derivatives")
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
if s == 1:
|
|
||||||
have_pole = True
|
|
||||||
for x in chi:
|
|
||||||
if x and x != 1:
|
|
||||||
have_pole = False
|
|
||||||
h = +ctx.eps
|
|
||||||
ctx.prec *= 2*(d+1)
|
|
||||||
s += h
|
|
||||||
if have_pole:
|
|
||||||
return +ctx.inf
|
|
||||||
z = ctx.zero
|
|
||||||
for p in range(1,q+1):
|
|
||||||
if chi[p%q]:
|
|
||||||
if d == 1:
|
|
||||||
z += chi[p%q] * (ctx.zeta(s, (p,q), 1) - \
|
|
||||||
ctx.zeta(s, (p,q))*ctx.log(q))
|
|
||||||
else:
|
|
||||||
z += chi[p%q] * ctx.zeta(s, (p,q))
|
|
||||||
z /= q**s
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return +z
|
|
||||||
|
|
@ -1,840 +0,0 @@
|
||||||
"""
|
|
||||||
Implements the PSLQ algorithm for integer relation detection,
|
|
||||||
and derivative algorithms for constant recognition.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from libmp import int_types, sqrt_fixed
|
|
||||||
|
|
||||||
# round to nearest integer (can be done more elegantly...)
|
|
||||||
def round_fixed(x, prec):
|
|
||||||
return ((x + (1<<(prec-1))) >> prec) << prec
|
|
||||||
|
|
||||||
class IdentificationMethods(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def pslq(ctx, x, tol=None, maxcoeff=1000, maxsteps=100, verbose=False):
|
|
||||||
r"""
|
|
||||||
Given a vector of real numbers `x = [x_0, x_1, ..., x_n]`, ``pslq(x)``
|
|
||||||
uses the PSLQ algorithm to find a list of integers
|
|
||||||
`[c_0, c_1, ..., c_n]` such that
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
|c_1 x_1 + c_2 x_2 + ... + c_n x_n| < \mathrm{tol}
|
|
||||||
|
|
||||||
and such that `\max |c_k| < \mathrm{maxcoeff}`. If no such vector
|
|
||||||
exists, :func:`pslq` returns ``None``. The tolerance defaults to
|
|
||||||
3/4 of the working precision.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Find rational approximations for `\pi`::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> pslq([-1, pi], tol=0.01)
|
|
||||||
[22, 7]
|
|
||||||
>>> pslq([-1, pi], tol=0.001)
|
|
||||||
[355, 113]
|
|
||||||
>>> mpf(22)/7; mpf(355)/113; +pi
|
|
||||||
3.14285714285714
|
|
||||||
3.14159292035398
|
|
||||||
3.14159265358979
|
|
||||||
|
|
||||||
Pi is not a rational number with denominator less than 1000::
|
|
||||||
|
|
||||||
>>> pslq([-1, pi])
|
|
||||||
>>>
|
|
||||||
|
|
||||||
To within the standard precision, it can however be approximated
|
|
||||||
by at least one rational number with denominator less than `10^{12}`::
|
|
||||||
|
|
||||||
>>> p, q = pslq([-1, pi], maxcoeff=10**12)
|
|
||||||
>>> print p, q
|
|
||||||
238410049439 75888275702
|
|
||||||
>>> mpf(p)/q
|
|
||||||
3.14159265358979
|
|
||||||
|
|
||||||
The PSLQ algorithm can be applied to long vectors. For example,
|
|
||||||
we can investigate the rational (in)dependence of integer square
|
|
||||||
roots::
|
|
||||||
|
|
||||||
>>> mp.dps = 30
|
|
||||||
>>> pslq([sqrt(n) for n in range(2, 5+1)])
|
|
||||||
>>>
|
|
||||||
>>> pslq([sqrt(n) for n in range(2, 6+1)])
|
|
||||||
>>>
|
|
||||||
>>> pslq([sqrt(n) for n in range(2, 8+1)])
|
|
||||||
[2, 0, 0, 0, 0, 0, -1]
|
|
||||||
|
|
||||||
**Machin formulas**
|
|
||||||
|
|
||||||
A famous formula for `\pi` is Machin's,
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\frac{\pi}{4} = 4 \operatorname{acot} 5 - \operatorname{acot} 239
|
|
||||||
|
|
||||||
There are actually infinitely many formulas of this type. Two
|
|
||||||
others are
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\frac{\pi}{4} = \operatorname{acot} 1
|
|
||||||
|
|
||||||
\frac{\pi}{4} = 12 \operatorname{acot} 49 + 32 \operatorname{acot} 57
|
|
||||||
+ 5 \operatorname{acot} 239 + 12 \operatorname{acot} 110443
|
|
||||||
|
|
||||||
We can easily verify the formulas using the PSLQ algorithm::
|
|
||||||
|
|
||||||
>>> mp.dps = 30
|
|
||||||
>>> pslq([pi/4, acot(1)])
|
|
||||||
[1, -1]
|
|
||||||
>>> pslq([pi/4, acot(5), acot(239)])
|
|
||||||
[1, -4, 1]
|
|
||||||
>>> pslq([pi/4, acot(49), acot(57), acot(239), acot(110443)])
|
|
||||||
[1, -12, -32, 5, -12]
|
|
||||||
|
|
||||||
We could try to generate a custom Machin-like formula by running
|
|
||||||
the PSLQ algorithm with a few inverse cotangent values, for example
|
|
||||||
acot(2), acot(3) ... acot(10). Unfortunately, there is a linear
|
|
||||||
dependence among these values, resulting in only that dependence
|
|
||||||
being detected, with a zero coefficient for `\pi`::
|
|
||||||
|
|
||||||
>>> pslq([pi] + [acot(n) for n in range(2,11)])
|
|
||||||
[0, 1, -1, 0, 0, 0, -1, 0, 0, 0]
|
|
||||||
|
|
||||||
We get better luck by removing linearly dependent terms::
|
|
||||||
|
|
||||||
>>> pslq([pi] + [acot(n) for n in range(2,11) if n not in (3, 5)])
|
|
||||||
[1, -8, 0, 0, 4, 0, 0, 0]
|
|
||||||
|
|
||||||
In other words, we found the following formula::
|
|
||||||
|
|
||||||
>>> 8*acot(2) - 4*acot(7)
|
|
||||||
3.14159265358979323846264338328
|
|
||||||
>>> +pi
|
|
||||||
3.14159265358979323846264338328
|
|
||||||
|
|
||||||
**Algorithm**
|
|
||||||
|
|
||||||
This is a fairly direct translation to Python of the pseudocode given by
|
|
||||||
David Bailey, "The PSLQ Integer Relation Algorithm":
|
|
||||||
http://www.cecm.sfu.ca/organics/papers/bailey/paper/html/node3.html
|
|
||||||
|
|
||||||
The present implementation uses fixed-point instead of floating-point
|
|
||||||
arithmetic, since this is significantly (about 7x) faster.
|
|
||||||
"""
|
|
||||||
|
|
||||||
n = len(x)
|
|
||||||
assert n >= 2
|
|
||||||
|
|
||||||
# At too low precision, the algorithm becomes meaningless
|
|
||||||
prec = ctx.prec
|
|
||||||
assert prec >= 53
|
|
||||||
|
|
||||||
if verbose and prec // max(2,n) < 5:
|
|
||||||
print "Warning: precision for PSLQ may be too low"
|
|
||||||
|
|
||||||
target = int(prec * 0.75)
|
|
||||||
|
|
||||||
if tol is None:
|
|
||||||
tol = ctx.mpf(2)**(-target)
|
|
||||||
else:
|
|
||||||
tol = ctx.convert(tol)
|
|
||||||
|
|
||||||
extra = 60
|
|
||||||
prec += extra
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "PSLQ using prec %i and tol %s" % (prec, ctx.nstr(tol))
|
|
||||||
|
|
||||||
tol = ctx.to_fixed(tol, prec)
|
|
||||||
assert tol
|
|
||||||
|
|
||||||
# Convert to fixed-point numbers. The dummy None is added so we can
|
|
||||||
# use 1-based indexing. (This just allows us to be consistent with
|
|
||||||
# Bailey's indexing. The algorithm is 100 lines long, so debugging
|
|
||||||
# a single wrong index can be painful.)
|
|
||||||
x = [None] + [ctx.to_fixed(ctx.mpf(xk), prec) for xk in x]
|
|
||||||
|
|
||||||
# Sanity check on magnitudes
|
|
||||||
minx = min(abs(xx) for xx in x[1:])
|
|
||||||
if not minx:
|
|
||||||
raise ValueError("PSLQ requires a vector of nonzero numbers")
|
|
||||||
if minx < tol//100:
|
|
||||||
if verbose:
|
|
||||||
print "STOPPING: (one number is too small)"
|
|
||||||
return None
|
|
||||||
|
|
||||||
g = sqrt_fixed((4<<prec)//3, prec)
|
|
||||||
A = {}
|
|
||||||
B = {}
|
|
||||||
H = {}
|
|
||||||
# Initialization
|
|
||||||
# step 1
|
|
||||||
for i in xrange(1, n+1):
|
|
||||||
for j in xrange(1, n+1):
|
|
||||||
A[i,j] = B[i,j] = (i==j) << prec
|
|
||||||
H[i,j] = 0
|
|
||||||
# step 2
|
|
||||||
s = [None] + [0] * n
|
|
||||||
for k in xrange(1, n+1):
|
|
||||||
t = 0
|
|
||||||
for j in xrange(k, n+1):
|
|
||||||
t += (x[j]**2 >> prec)
|
|
||||||
s[k] = sqrt_fixed(t, prec)
|
|
||||||
t = s[1]
|
|
||||||
y = x[:]
|
|
||||||
for k in xrange(1, n+1):
|
|
||||||
y[k] = (x[k] << prec) // t
|
|
||||||
s[k] = (s[k] << prec) // t
|
|
||||||
# step 3
|
|
||||||
for i in xrange(1, n+1):
|
|
||||||
for j in xrange(i+1, n):
|
|
||||||
H[i,j] = 0
|
|
||||||
if i <= n-1:
|
|
||||||
if s[i]:
|
|
||||||
H[i,i] = (s[i+1] << prec) // s[i]
|
|
||||||
else:
|
|
||||||
H[i,i] = 0
|
|
||||||
for j in range(1, i):
|
|
||||||
sjj1 = s[j]*s[j+1]
|
|
||||||
if sjj1:
|
|
||||||
H[i,j] = ((-y[i]*y[j])<<prec)//sjj1
|
|
||||||
else:
|
|
||||||
H[i,j] = 0
|
|
||||||
# step 4
|
|
||||||
for i in xrange(2, n+1):
|
|
||||||
for j in xrange(i-1, 0, -1):
|
|
||||||
#t = floor(H[i,j]/H[j,j] + 0.5)
|
|
||||||
if H[j,j]:
|
|
||||||
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
|
|
||||||
else:
|
|
||||||
#t = 0
|
|
||||||
continue
|
|
||||||
y[j] = y[j] + (t*y[i] >> prec)
|
|
||||||
for k in xrange(1, j+1):
|
|
||||||
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
|
|
||||||
for k in xrange(1, n+1):
|
|
||||||
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
|
|
||||||
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
|
|
||||||
# Main algorithm
|
|
||||||
for REP in range(maxsteps):
|
|
||||||
# Step 1
|
|
||||||
m = -1
|
|
||||||
szmax = -1
|
|
||||||
for i in range(1, n):
|
|
||||||
h = H[i,i]
|
|
||||||
sz = (g**i * abs(h)) >> (prec*(i-1))
|
|
||||||
if sz > szmax:
|
|
||||||
m = i
|
|
||||||
szmax = sz
|
|
||||||
# Step 2
|
|
||||||
y[m], y[m+1] = y[m+1], y[m]
|
|
||||||
tmp = {}
|
|
||||||
for i in xrange(1,n+1): H[m,i], H[m+1,i] = H[m+1,i], H[m,i]
|
|
||||||
for i in xrange(1,n+1): A[m,i], A[m+1,i] = A[m+1,i], A[m,i]
|
|
||||||
for i in xrange(1,n+1): B[i,m], B[i,m+1] = B[i,m+1], B[i,m]
|
|
||||||
# Step 3
|
|
||||||
if m <= n - 2:
|
|
||||||
t0 = sqrt_fixed((H[m,m]**2 + H[m,m+1]**2)>>prec, prec)
|
|
||||||
# A zero element probably indicates that the precision has
|
|
||||||
# been exhausted. XXX: this could be spurious, due to
|
|
||||||
# using fixed-point arithmetic
|
|
||||||
if not t0:
|
|
||||||
break
|
|
||||||
t1 = (H[m,m] << prec) // t0
|
|
||||||
t2 = (H[m,m+1] << prec) // t0
|
|
||||||
for i in xrange(m, n+1):
|
|
||||||
t3 = H[i,m]
|
|
||||||
t4 = H[i,m+1]
|
|
||||||
H[i,m] = (t1*t3+t2*t4) >> prec
|
|
||||||
H[i,m+1] = (-t2*t3+t1*t4) >> prec
|
|
||||||
# Step 4
|
|
||||||
for i in xrange(m+1, n+1):
|
|
||||||
for j in xrange(min(i-1, m+1), 0, -1):
|
|
||||||
try:
|
|
||||||
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
|
|
||||||
# Precision probably exhausted
|
|
||||||
except ZeroDivisionError:
|
|
||||||
break
|
|
||||||
y[j] = y[j] + ((t*y[i]) >> prec)
|
|
||||||
for k in xrange(1, j+1):
|
|
||||||
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
|
|
||||||
for k in xrange(1, n+1):
|
|
||||||
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
|
|
||||||
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
|
|
||||||
# Until a relation is found, the error typically decreases
|
|
||||||
# slowly (e.g. a factor 1-10) with each step TODO: we could
|
|
||||||
# compare err from two successive iterations. If there is a
|
|
||||||
# large drop (several orders of magnitude), that indicates a
|
|
||||||
# "high quality" relation was detected. Reporting this to
|
|
||||||
# the user somehow might be useful.
|
|
||||||
best_err = maxcoeff<<prec
|
|
||||||
for i in xrange(1, n+1):
|
|
||||||
err = abs(y[i])
|
|
||||||
# Maybe we are done?
|
|
||||||
if err < tol:
|
|
||||||
# We are done if the coefficients are acceptable
|
|
||||||
vec = [int(round_fixed(B[j,i], prec) >> prec) for j in \
|
|
||||||
range(1,n+1)]
|
|
||||||
if max(abs(v) for v in vec) < maxcoeff:
|
|
||||||
if verbose:
|
|
||||||
print "FOUND relation at iter %i/%i, error: %s" % \
|
|
||||||
(REP, maxsteps, ctx.nstr(err / ctx.mpf(2)**prec, 1))
|
|
||||||
return vec
|
|
||||||
best_err = min(err, best_err)
|
|
||||||
# Calculate a lower bound for the norm. We could do this
|
|
||||||
# more exactly (using the Euclidean norm) but there is probably
|
|
||||||
# no practical benefit.
|
|
||||||
recnorm = max(abs(h) for h in H.values())
|
|
||||||
if recnorm:
|
|
||||||
norm = ((1 << (2*prec)) // recnorm) >> prec
|
|
||||||
norm //= 100
|
|
||||||
else:
|
|
||||||
norm = ctx.inf
|
|
||||||
if verbose:
|
|
||||||
print "%i/%i: Error: %8s Norm: %s" % \
|
|
||||||
(REP, maxsteps, ctx.nstr(best_err / ctx.mpf(2)**prec, 1), norm)
|
|
||||||
if norm >= maxcoeff:
|
|
||||||
break
|
|
||||||
if verbose:
|
|
||||||
print "CANCELLING after step %i/%i." % (REP, maxsteps)
|
|
||||||
print "Could not find an integer relation. Norm bound: %s" % norm
|
|
||||||
return None
|
|
||||||
|
|
||||||
def findpoly(ctx, x, n=1, **kwargs):
|
|
||||||
r"""
|
|
||||||
``findpoly(x, n)`` returns the coefficients of an integer
|
|
||||||
polynomial `P` of degree at most `n` such that `P(x) \approx 0`.
|
|
||||||
If no polynomial having `x` as a root can be found,
|
|
||||||
:func:`findpoly` returns ``None``.
|
|
||||||
|
|
||||||
:func:`findpoly` works by successively calling :func:`pslq` with
|
|
||||||
the vectors `[1, x]`, `[1, x, x^2]`, `[1, x, x^2, x^3]`, ...,
|
|
||||||
`[1, x, x^2, .., x^n]` as input. Keyword arguments given to
|
|
||||||
:func:`findpoly` are forwarded verbatim to :func:`pslq`. In
|
|
||||||
particular, you can specify a tolerance for `P(x)` with ``tol``
|
|
||||||
and a maximum permitted coefficient size with ``maxcoeff``.
|
|
||||||
|
|
||||||
For large values of `n`, it is recommended to run :func:`findpoly`
|
|
||||||
at high precision; preferably 50 digits or more.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
By default (degree `n = 1`), :func:`findpoly` simply finds a linear
|
|
||||||
polynomial with a rational root::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> findpoly(0.7)
|
|
||||||
[-10, 7]
|
|
||||||
|
|
||||||
The generated coefficient list is valid input to ``polyval`` and
|
|
||||||
``polyroots``::
|
|
||||||
|
|
||||||
>>> nprint(polyval(findpoly(phi, 2), phi), 1)
|
|
||||||
-2.0e-16
|
|
||||||
>>> for r in polyroots(findpoly(phi, 2)):
|
|
||||||
... print r
|
|
||||||
...
|
|
||||||
-0.618033988749895
|
|
||||||
1.61803398874989
|
|
||||||
|
|
||||||
Numbers of the form `m + n \sqrt p` for integers `(m, n, p)` are
|
|
||||||
solutions to quadratic equations. As we find here, `1+\sqrt 2`
|
|
||||||
is a root of the polynomial `x^2 - 2x - 1`::
|
|
||||||
|
|
||||||
>>> findpoly(1+sqrt(2), 2)
|
|
||||||
[1, -2, -1]
|
|
||||||
>>> findroot(lambda x: x**2 - 2*x - 1, 1)
|
|
||||||
2.4142135623731
|
|
||||||
|
|
||||||
Despite only containing square roots, the following number results
|
|
||||||
in a polynomial of degree 4::
|
|
||||||
|
|
||||||
>>> findpoly(sqrt(2)+sqrt(3), 4)
|
|
||||||
[1, 0, -10, 0, 1]
|
|
||||||
|
|
||||||
In fact, `x^4 - 10x^2 + 1` is the *minimal polynomial* of
|
|
||||||
`r = \sqrt 2 + \sqrt 3`, meaning that a rational polynomial of
|
|
||||||
lower degree having `r` as a root does not exist. Given sufficient
|
|
||||||
precision, :func:`findpoly` will usually find the correct
|
|
||||||
minimal polynomial of a given algebraic number.
|
|
||||||
|
|
||||||
**Non-algebraic numbers**
|
|
||||||
|
|
||||||
If :func:`findpoly` fails to find a polynomial with given
|
|
||||||
coefficient size and tolerance constraints, that means no such
|
|
||||||
polynomial exists.
|
|
||||||
|
|
||||||
We can verify that `\pi` is not an algebraic number of degree 3 with
|
|
||||||
coefficients less than 1000::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> findpoly(pi, 3)
|
|
||||||
>>>
|
|
||||||
|
|
||||||
It is always possible to find an algebraic approximation of a number
|
|
||||||
using one (or several) of the following methods:
|
|
||||||
|
|
||||||
1. Increasing the permitted degree
|
|
||||||
2. Allowing larger coefficients
|
|
||||||
3. Reducing the tolerance
|
|
||||||
|
|
||||||
One example of each method is shown below::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> findpoly(pi, 4)
|
|
||||||
[95, -545, 863, -183, -298]
|
|
||||||
>>> findpoly(pi, 3, maxcoeff=10000)
|
|
||||||
[836, -1734, -2658, -457]
|
|
||||||
>>> findpoly(pi, 3, tol=1e-7)
|
|
||||||
[-4, 22, -29, -2]
|
|
||||||
|
|
||||||
It is unknown whether Euler's constant is transcendental (or even
|
|
||||||
irrational). We can use :func:`findpoly` to check that if is
|
|
||||||
an algebraic number, its minimal polynomial must have degree
|
|
||||||
at least 7 and a coefficient of magnitude at least 1000000::
|
|
||||||
|
|
||||||
>>> mp.dps = 200
|
|
||||||
>>> findpoly(euler, 6, maxcoeff=10**6, tol=1e-100, maxsteps=1000)
|
|
||||||
>>>
|
|
||||||
|
|
||||||
Note that the high precision and strict tolerance is necessary
|
|
||||||
for such high-degree runs, since otherwise unwanted low-accuracy
|
|
||||||
approximations will be detected. It may also be necessary to set
|
|
||||||
maxsteps high to prevent a premature exit (before the coefficient
|
|
||||||
bound has been reached). Running with ``verbose=True`` to get an
|
|
||||||
idea what is happening can be useful.
|
|
||||||
"""
|
|
||||||
x = ctx.mpf(x)
|
|
||||||
assert n >= 1
|
|
||||||
if x == 0:
|
|
||||||
return [1, 0]
|
|
||||||
xs = [ctx.mpf(1)]
|
|
||||||
for i in range(1,n+1):
|
|
||||||
xs.append(x**i)
|
|
||||||
a = ctx.pslq(xs, **kwargs)
|
|
||||||
if a is not None:
|
|
||||||
return a[::-1]
|
|
||||||
|
|
||||||
def fracgcd(p, q):
|
|
||||||
x, y = p, q
|
|
||||||
while y:
|
|
||||||
x, y = y, x % y
|
|
||||||
if x != 1:
|
|
||||||
p //= x
|
|
||||||
q //= x
|
|
||||||
if q == 1:
|
|
||||||
return p
|
|
||||||
return p, q
|
|
||||||
|
|
||||||
def pslqstring(r, constants):
|
|
||||||
q = r[0]
|
|
||||||
r = r[1:]
|
|
||||||
s = []
|
|
||||||
for i in range(len(r)):
|
|
||||||
p = r[i]
|
|
||||||
if p:
|
|
||||||
z = fracgcd(-p,q)
|
|
||||||
cs = constants[i][1]
|
|
||||||
if cs == '1':
|
|
||||||
cs = ''
|
|
||||||
else:
|
|
||||||
cs = '*' + cs
|
|
||||||
if isinstance(z, int_types):
|
|
||||||
if z > 0: term = str(z) + cs
|
|
||||||
else: term = ("(%s)" % z) + cs
|
|
||||||
else:
|
|
||||||
term = ("(%s/%s)" % z) + cs
|
|
||||||
s.append(term)
|
|
||||||
s = ' + '.join(s)
|
|
||||||
if '+' in s or '*' in s:
|
|
||||||
s = '(' + s + ')'
|
|
||||||
return s or '0'
|
|
||||||
|
|
||||||
def prodstring(r, constants):
|
|
||||||
q = r[0]
|
|
||||||
r = r[1:]
|
|
||||||
num = []
|
|
||||||
den = []
|
|
||||||
for i in range(len(r)):
|
|
||||||
p = r[i]
|
|
||||||
if p:
|
|
||||||
z = fracgcd(-p,q)
|
|
||||||
cs = constants[i][1]
|
|
||||||
if isinstance(z, int_types):
|
|
||||||
if abs(z) == 1: t = cs
|
|
||||||
else: t = '%s**%s' % (cs, abs(z))
|
|
||||||
([num,den][z<0]).append(t)
|
|
||||||
else:
|
|
||||||
t = '%s**(%s/%s)' % (cs, abs(z[0]), z[1])
|
|
||||||
([num,den][z[0]<0]).append(t)
|
|
||||||
num = '*'.join(num)
|
|
||||||
den = '*'.join(den)
|
|
||||||
if num and den: return "(%s)/(%s)" % (num, den)
|
|
||||||
if num: return num
|
|
||||||
if den: return "1/(%s)" % den
|
|
||||||
|
|
||||||
def quadraticstring(ctx,t,a,b,c):
|
|
||||||
if c < 0:
|
|
||||||
a,b,c = -a,-b,-c
|
|
||||||
u1 = (-b+ctx.sqrt(b**2-4*a*c))/(2*c)
|
|
||||||
u2 = (-b-ctx.sqrt(b**2-4*a*c))/(2*c)
|
|
||||||
if abs(u1-t) < abs(u2-t):
|
|
||||||
if b: s = '((%s+sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
|
|
||||||
else: s = '(sqrt(%s)/%s)' % (-4*a*c,2*c)
|
|
||||||
else:
|
|
||||||
if b: s = '((%s-sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
|
|
||||||
else: s = '(-sqrt(%s)/%s)' % (-4*a*c,2*c)
|
|
||||||
return s
|
|
||||||
|
|
||||||
# Transformation y = f(x,c), with inverse function x = f(y,c)
|
|
||||||
# The third entry indicates whether the transformation is
|
|
||||||
# redundant when c = 1
|
|
||||||
transforms = [
|
|
||||||
(lambda ctx,x,c: x*c, '$y/$c', 0),
|
|
||||||
(lambda ctx,x,c: x/c, '$c*$y', 1),
|
|
||||||
(lambda ctx,x,c: c/x, '$c/$y', 0),
|
|
||||||
(lambda ctx,x,c: (x*c)**2, 'sqrt($y)/$c', 0),
|
|
||||||
(lambda ctx,x,c: (x/c)**2, '$c*sqrt($y)', 1),
|
|
||||||
(lambda ctx,x,c: (c/x)**2, '$c/sqrt($y)', 0),
|
|
||||||
(lambda ctx,x,c: c*x**2, 'sqrt($y)/sqrt($c)', 1),
|
|
||||||
(lambda ctx,x,c: x**2/c, 'sqrt($c)*sqrt($y)', 1),
|
|
||||||
(lambda ctx,x,c: c/x**2, 'sqrt($c)/sqrt($y)', 1),
|
|
||||||
(lambda ctx,x,c: ctx.sqrt(x*c), '$y**2/$c', 0),
|
|
||||||
(lambda ctx,x,c: ctx.sqrt(x/c), '$c*$y**2', 1),
|
|
||||||
(lambda ctx,x,c: ctx.sqrt(c/x), '$c/$y**2', 0),
|
|
||||||
(lambda ctx,x,c: c*ctx.sqrt(x), '$y**2/$c**2', 1),
|
|
||||||
(lambda ctx,x,c: ctx.sqrt(x)/c, '$c**2*$y**2', 1),
|
|
||||||
(lambda ctx,x,c: c/ctx.sqrt(x), '$c**2/$y**2', 1),
|
|
||||||
(lambda ctx,x,c: ctx.exp(x*c), 'log($y)/$c', 0),
|
|
||||||
(lambda ctx,x,c: ctx.exp(x/c), '$c*log($y)', 1),
|
|
||||||
(lambda ctx,x,c: ctx.exp(c/x), '$c/log($y)', 0),
|
|
||||||
(lambda ctx,x,c: c*ctx.exp(x), 'log($y/$c)', 1),
|
|
||||||
(lambda ctx,x,c: ctx.exp(x)/c, 'log($c*$y)', 1),
|
|
||||||
(lambda ctx,x,c: c/ctx.exp(x), 'log($c/$y)', 0),
|
|
||||||
(lambda ctx,x,c: ctx.ln(x*c), 'exp($y)/$c', 0),
|
|
||||||
(lambda ctx,x,c: ctx.ln(x/c), '$c*exp($y)', 1),
|
|
||||||
(lambda ctx,x,c: ctx.ln(c/x), '$c/exp($y)', 0),
|
|
||||||
(lambda ctx,x,c: c*ctx.ln(x), 'exp($y/$c)', 1),
|
|
||||||
(lambda ctx,x,c: ctx.ln(x)/c, 'exp($c*$y)', 1),
|
|
||||||
(lambda ctx,x,c: c/ctx.ln(x), 'exp($c/$y)', 0),
|
|
||||||
]
|
|
||||||
|
|
||||||
def identify(ctx, x, constants=[], tol=None, maxcoeff=1000, full=False,
|
|
||||||
verbose=False):
|
|
||||||
"""
|
|
||||||
Given a real number `x`, ``identify(x)`` attempts to find an exact
|
|
||||||
formula for `x`. This formula is returned as a string. If no match
|
|
||||||
is found, ``None`` is returned. With ``full=True``, a list of
|
|
||||||
matching formulas is returned.
|
|
||||||
|
|
||||||
As a simple example, :func:`identify` will find an algebraic
|
|
||||||
formula for the golden ratio::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> identify(phi)
|
|
||||||
'((1+sqrt(5))/2)'
|
|
||||||
|
|
||||||
:func:`identify` can identify simple algebraic numbers and simple
|
|
||||||
combinations of given base constants, as well as certain basic
|
|
||||||
transformations thereof. More specifically, :func:`identify`
|
|
||||||
looks for the following:
|
|
||||||
|
|
||||||
1. Fractions
|
|
||||||
2. Quadratic algebraic numbers
|
|
||||||
3. Rational linear combinations of the base constants
|
|
||||||
4. Any of the above after first transforming `x` into `f(x)` where
|
|
||||||
`f(x)` is `1/x`, `\sqrt x`, `x^2`, `\log x` or `\exp x`, either
|
|
||||||
directly or with `x` or `f(x)` multiplied or divided by one of
|
|
||||||
the base constants
|
|
||||||
5. Products of fractional powers of the base constants and
|
|
||||||
small integers
|
|
||||||
|
|
||||||
Base constants can be given as a list of strings representing mpmath
|
|
||||||
expressions (:func:`identify` will ``eval`` the strings to numerical
|
|
||||||
values and use the original strings for the output), or as a dict of
|
|
||||||
formula:value pairs.
|
|
||||||
|
|
||||||
In order not to produce spurious results, :func:`identify` should
|
|
||||||
be used with high precision; preferrably 50 digits or more.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Simple identifications can be performed safely at standard
|
|
||||||
precision. Here the default recognition of rational, algebraic,
|
|
||||||
and exp/log of algebraic numbers is demonstrated::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> identify(0.22222222222222222)
|
|
||||||
'(2/9)'
|
|
||||||
>>> identify(1.9662210973805663)
|
|
||||||
'sqrt(((24+sqrt(48))/8))'
|
|
||||||
>>> identify(4.1132503787829275)
|
|
||||||
'exp((sqrt(8)/2))'
|
|
||||||
>>> identify(0.881373587019543)
|
|
||||||
'log(((2+sqrt(8))/2))'
|
|
||||||
|
|
||||||
By default, :func:`identify` does not recognize `\pi`. At standard
|
|
||||||
precision it finds a not too useful approximation. At slightly
|
|
||||||
increased precision, this approximation is no longer accurate
|
|
||||||
enough and :func:`identify` more correctly returns ``None``::
|
|
||||||
|
|
||||||
>>> identify(pi)
|
|
||||||
'(2**(176/117)*3**(20/117)*5**(35/39))/(7**(92/117))'
|
|
||||||
>>> mp.dps = 30
|
|
||||||
>>> identify(pi)
|
|
||||||
>>>
|
|
||||||
|
|
||||||
Numbers such as `\pi`, and simple combinations of user-defined
|
|
||||||
constants, can be identified if they are provided explicitly::
|
|
||||||
|
|
||||||
>>> identify(3*pi-2*e, ['pi', 'e'])
|
|
||||||
'(3*pi + (-2)*e)'
|
|
||||||
|
|
||||||
Here is an example using a dict of constants. Note that the
|
|
||||||
constants need not be "atomic"; :func:`identify` can just
|
|
||||||
as well express the given number in terms of expressions
|
|
||||||
given by formulas::
|
|
||||||
|
|
||||||
>>> identify(pi+e, {'a':pi+2, 'b':2*e})
|
|
||||||
'((-2) + 1*a + (1/2)*b)'
|
|
||||||
|
|
||||||
Next, we attempt some identifications with a set of base constants.
|
|
||||||
It is necessary to increase the precision a bit.
|
|
||||||
|
|
||||||
>>> mp.dps = 50
|
|
||||||
>>> base = ['sqrt(2)','pi','log(2)']
|
|
||||||
>>> identify(0.25, base)
|
|
||||||
'(1/4)'
|
|
||||||
>>> identify(3*pi + 2*sqrt(2) + 5*log(2)/7, base)
|
|
||||||
'(2*sqrt(2) + 3*pi + (5/7)*log(2))'
|
|
||||||
>>> identify(exp(pi+2), base)
|
|
||||||
'exp((2 + 1*pi))'
|
|
||||||
>>> identify(1/(3+sqrt(2)), base)
|
|
||||||
'((3/7) + (-1/7)*sqrt(2))'
|
|
||||||
>>> identify(sqrt(2)/(3*pi+4), base)
|
|
||||||
'sqrt(2)/(4 + 3*pi)'
|
|
||||||
>>> identify(5**(mpf(1)/3)*pi*log(2)**2, base)
|
|
||||||
'5**(1/3)*pi*log(2)**2'
|
|
||||||
|
|
||||||
An example of an erroneous solution being found when too low
|
|
||||||
precision is used::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
|
|
||||||
'((11/25) + (-158/75)*pi + (76/75)*e + (44/15)*sqrt(2))'
|
|
||||||
>>> mp.dps = 50
|
|
||||||
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
|
|
||||||
'1/(3*pi + (-4)*e + 2*sqrt(2))'
|
|
||||||
|
|
||||||
**Finding approximate solutions**
|
|
||||||
|
|
||||||
The tolerance ``tol`` defaults to 3/4 of the working precision.
|
|
||||||
Lowering the tolerance is useful for finding approximate matches.
|
|
||||||
We can for example try to generate approximations for pi::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> identify(pi, tol=1e-2)
|
|
||||||
'(22/7)'
|
|
||||||
>>> identify(pi, tol=1e-3)
|
|
||||||
'(355/113)'
|
|
||||||
>>> identify(pi, tol=1e-10)
|
|
||||||
'(5**(339/269))/(2**(64/269)*3**(13/269)*7**(92/269))'
|
|
||||||
|
|
||||||
With ``full=True``, and by supplying a few base constants,
|
|
||||||
``identify`` can generate almost endless lists of approximations
|
|
||||||
for any number (the output below has been truncated to show only
|
|
||||||
the first few)::
|
|
||||||
|
|
||||||
>>> for p in identify(pi, ['e', 'catalan'], tol=1e-5, full=True):
|
|
||||||
... print p
|
|
||||||
... # doctest: +ELLIPSIS
|
|
||||||
e/log((6 + (-4/3)*e))
|
|
||||||
(3**3*5*e*catalan**2)/(2*7**2)
|
|
||||||
sqrt(((-13) + 1*e + 22*catalan))
|
|
||||||
log(((-6) + 24*e + 4*catalan)/e)
|
|
||||||
exp(catalan*((-1/5) + (8/15)*e))
|
|
||||||
catalan*(6 + (-6)*e + 15*catalan)
|
|
||||||
sqrt((5 + 26*e + (-3)*catalan))/e
|
|
||||||
e*sqrt(((-27) + 2*e + 25*catalan))
|
|
||||||
log(((-1) + (-11)*e + 59*catalan))
|
|
||||||
((3/20) + (21/20)*e + (3/20)*catalan)
|
|
||||||
...
|
|
||||||
|
|
||||||
The numerical values are roughly as close to pi as permitted by the
|
|
||||||
specified tolerance:
|
|
||||||
|
|
||||||
>>> e/log(6-4*e/3)
|
|
||||||
3.14157719846001
|
|
||||||
>>> 135*e*catalan**2/98
|
|
||||||
3.14166950419369
|
|
||||||
>>> sqrt(e-13+22*catalan)
|
|
||||||
3.14158000062992
|
|
||||||
>>> log(24*e-6+4*catalan)-1
|
|
||||||
3.14158791577159
|
|
||||||
|
|
||||||
**Symbolic processing**
|
|
||||||
|
|
||||||
The output formula can be evaluated as a Python expression.
|
|
||||||
Note however that if fractions (like '2/3') are present in
|
|
||||||
the formula, Python's :func:`eval()` may erroneously perform
|
|
||||||
integer division. Note also that the output is not necessarily
|
|
||||||
in the algebraically simplest form::
|
|
||||||
|
|
||||||
>>> identify(sqrt(2))
|
|
||||||
'(sqrt(8)/2)'
|
|
||||||
|
|
||||||
As a solution to both problems, consider using SymPy's
|
|
||||||
:func:`sympify` to convert the formula into a symbolic expression.
|
|
||||||
SymPy can be used to pretty-print or further simplify the formula
|
|
||||||
symbolically::
|
|
||||||
|
|
||||||
>>> from sympy import sympify
|
|
||||||
>>> sympify(identify(sqrt(2)))
|
|
||||||
2**(1/2)
|
|
||||||
|
|
||||||
Sometimes :func:`identify` can simplify an expression further than
|
|
||||||
a symbolic algorithm::
|
|
||||||
|
|
||||||
>>> from sympy import simplify
|
|
||||||
>>> x = sympify('-1/(-3/2+(1/2)*5**(1/2))*(3/2-1/2*5**(1/2))**(1/2)')
|
|
||||||
>>> x
|
|
||||||
(3/2 - 5**(1/2)/2)**(-1/2)
|
|
||||||
>>> x = simplify(x)
|
|
||||||
>>> x
|
|
||||||
2/(6 - 2*5**(1/2))**(1/2)
|
|
||||||
>>> mp.dps = 30
|
|
||||||
>>> x = sympify(identify(x.evalf(30)))
|
|
||||||
>>> x
|
|
||||||
1/2 + 5**(1/2)/2
|
|
||||||
|
|
||||||
(In fact, this functionality is available directly in SymPy as the
|
|
||||||
function :func:`nsimplify`, which is essentially a wrapper for
|
|
||||||
:func:`identify`.)
|
|
||||||
|
|
||||||
**Miscellaneous issues and limitations**
|
|
||||||
|
|
||||||
The input `x` must be a real number. All base constants must be
|
|
||||||
positive real numbers and must not be rationals or rational linear
|
|
||||||
combinations of each other.
|
|
||||||
|
|
||||||
The worst-case computation time grows quickly with the number of
|
|
||||||
base constants. Already with 3 or 4 base constants,
|
|
||||||
:func:`identify` may require several seconds to finish. To search
|
|
||||||
for relations among a large number of constants, you should
|
|
||||||
consider using :func:`pslq` directly.
|
|
||||||
|
|
||||||
The extended transformations are applied to x, not the constants
|
|
||||||
separately. As a result, ``identify`` will for example be able to
|
|
||||||
recognize ``exp(2*pi+3)`` with ``pi`` given as a base constant, but
|
|
||||||
not ``2*exp(pi)+3``. It will be able to recognize the latter if
|
|
||||||
``exp(pi)`` is given explicitly as a base constant.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
solutions = []
|
|
||||||
|
|
||||||
def addsolution(s):
|
|
||||||
if verbose: print "Found: ", s
|
|
||||||
solutions.append(s)
|
|
||||||
|
|
||||||
x = ctx.mpf(x)
|
|
||||||
|
|
||||||
# Further along, x will be assumed positive
|
|
||||||
if x == 0:
|
|
||||||
if full: return ['0']
|
|
||||||
else: return '0'
|
|
||||||
if x < 0:
|
|
||||||
sol = ctx.identify(-x, constants, tol, maxcoeff, full, verbose)
|
|
||||||
if sol is None:
|
|
||||||
return sol
|
|
||||||
if full:
|
|
||||||
return ["-(%s)"%s for s in sol]
|
|
||||||
else:
|
|
||||||
return "-(%s)" % sol
|
|
||||||
|
|
||||||
if tol:
|
|
||||||
tol = ctx.mpf(tol)
|
|
||||||
else:
|
|
||||||
tol = ctx.eps**0.7
|
|
||||||
M = maxcoeff
|
|
||||||
|
|
||||||
if constants:
|
|
||||||
if isinstance(constants, dict):
|
|
||||||
constants = [(ctx.mpf(v), name) for (name, v) in constants.items()]
|
|
||||||
else:
|
|
||||||
namespace = dict((name, getattr(ctx,name)) for name in dir(ctx))
|
|
||||||
constants = [(eval(p, namespace), p) for p in constants]
|
|
||||||
else:
|
|
||||||
constants = []
|
|
||||||
|
|
||||||
# We always want to find at least rational terms
|
|
||||||
if 1 not in [value for (name, value) in constants]:
|
|
||||||
constants = [(ctx.mpf(1), '1')] + constants
|
|
||||||
|
|
||||||
# PSLQ with simple algebraic and functional transformations
|
|
||||||
for ft, ftn, red in transforms:
|
|
||||||
for c, cn in constants:
|
|
||||||
if red and cn == '1':
|
|
||||||
continue
|
|
||||||
t = ft(ctx,x,c)
|
|
||||||
# Prevent exponential transforms from wreaking havoc
|
|
||||||
if abs(t) > M**2 or abs(t) < tol:
|
|
||||||
continue
|
|
||||||
# Linear combination of base constants
|
|
||||||
r = ctx.pslq([t] + [a[0] for a in constants], tol, M)
|
|
||||||
s = None
|
|
||||||
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
|
|
||||||
s = pslqstring(r, constants)
|
|
||||||
# Quadratic algebraic numbers
|
|
||||||
else:
|
|
||||||
q = ctx.pslq([ctx.one, t, t**2], tol, M)
|
|
||||||
if q is not None and len(q) == 3 and q[2]:
|
|
||||||
aa, bb, cc = q
|
|
||||||
if max(abs(aa),abs(bb),abs(cc)) <= M:
|
|
||||||
s = quadraticstring(ctx,t,aa,bb,cc)
|
|
||||||
if s:
|
|
||||||
if cn == '1' and ('/$c' in ftn):
|
|
||||||
s = ftn.replace('$y', s).replace('/$c', '')
|
|
||||||
else:
|
|
||||||
s = ftn.replace('$y', s).replace('$c', cn)
|
|
||||||
addsolution(s)
|
|
||||||
if not full: return solutions[0]
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "."
|
|
||||||
|
|
||||||
# Check for a direct multiplicative formula
|
|
||||||
if x != 1:
|
|
||||||
# Allow fractional powers of fractions
|
|
||||||
ilogs = [2,3,5,7]
|
|
||||||
# Watch out for existing fractional powers of fractions
|
|
||||||
logs = []
|
|
||||||
for a, s in constants:
|
|
||||||
if not sum(bool(ctx.findpoly(ctx.ln(a)/ctx.ln(i),1)) for i in ilogs):
|
|
||||||
logs.append((ctx.ln(a), s))
|
|
||||||
logs = [(ctx.ln(i),str(i)) for i in ilogs] + logs
|
|
||||||
r = ctx.pslq([ctx.ln(x)] + [a[0] for a in logs], tol, M)
|
|
||||||
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
|
|
||||||
addsolution(prodstring(r, logs))
|
|
||||||
if not full: return solutions[0]
|
|
||||||
|
|
||||||
if full:
|
|
||||||
return sorted(solutions, key=len)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
IdentificationMethods.pslq = pslq
|
|
||||||
IdentificationMethods.findpoly = findpoly
|
|
||||||
IdentificationMethods.identify = identify
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
from libmpf import (prec_to_dps, dps_to_prec, repr_dps,
|
|
||||||
round_down, round_up, round_floor, round_ceiling, round_nearest,
|
|
||||||
to_pickable, from_pickable, ComplexResult,
|
|
||||||
fzero, fnzero, fone, fnone, ftwo, ften, fhalf, fnan, finf, fninf,
|
|
||||||
math_float_inf, round_int, normalize, normalize1,
|
|
||||||
from_man_exp, from_int, to_man_exp, to_int, mpf_ceil, mpf_floor,
|
|
||||||
from_float, to_float, from_rational, to_rational, to_fixed,
|
|
||||||
mpf_rand, mpf_eq, mpf_hash, mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_ge,
|
|
||||||
mpf_pos, mpf_neg, mpf_abs, mpf_sign, mpf_add, mpf_sub, mpf_sum,
|
|
||||||
mpf_mul, mpf_mul_int, mpf_shift, mpf_frexp,
|
|
||||||
mpf_div, mpf_rdiv_int, mpf_mod, mpf_pow_int,
|
|
||||||
mpf_perturb,
|
|
||||||
to_digits_exp, to_str, str_to_man_exp, from_str, from_bstr, to_bstr,
|
|
||||||
mpf_sqrt, mpf_hypot)
|
|
||||||
|
|
||||||
from libmpc import (mpc_one, mpc_zero, mpc_two, mpc_half,
|
|
||||||
mpc_is_inf, mpc_is_infnan, mpc_to_str, mpc_to_complex, mpc_hash,
|
|
||||||
mpc_conjugate, mpc_is_nonzero, mpc_add, mpc_add_mpf,
|
|
||||||
mpc_sub, mpc_sub_mpf, mpc_pos, mpc_neg, mpc_shift, mpc_abs,
|
|
||||||
mpc_arg, mpc_floor, mpc_ceil, mpc_mul, mpc_square,
|
|
||||||
mpc_mul_mpf, mpc_mul_imag_mpf, mpc_mul_int,
|
|
||||||
mpc_div, mpc_div_mpf, mpc_reciprocal, mpc_mpf_div,
|
|
||||||
complex_int_pow, mpc_pow, mpc_pow_mpf, mpc_pow_int,
|
|
||||||
mpc_sqrt, mpc_nthroot, mpc_cbrt, mpc_exp, mpc_log, mpc_cos, mpc_sin,
|
|
||||||
mpc_tan, mpc_cos_pi, mpc_sin_pi, mpc_cosh, mpc_sinh, mpc_tanh,
|
|
||||||
mpc_atan, mpc_acos, mpc_asin, mpc_asinh, mpc_acosh, mpc_atanh,
|
|
||||||
mpc_fibonacci, mpf_expj, mpf_expjpi, mpc_expj, mpc_expjpi)
|
|
||||||
|
|
||||||
from libelefun import (ln2_fixed, mpf_ln2, ln10_fixed, mpf_ln10,
|
|
||||||
pi_fixed, mpf_pi, e_fixed, mpf_e, phi_fixed, mpf_phi,
|
|
||||||
degree_fixed, mpf_degree,
|
|
||||||
mpf_pow, mpf_nthroot, mpf_cbrt, log_int_fixed, agm_fixed,
|
|
||||||
mpf_log, mpf_log_hypot, mpf_exp, mpf_cos_sin, mpf_cos, mpf_sin, mpf_tan,
|
|
||||||
mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi, mpf_cosh_sinh,
|
|
||||||
mpf_cosh, mpf_sinh, mpf_tanh, mpf_atan, mpf_atan2, mpf_asin,
|
|
||||||
mpf_acos, mpf_asinh, mpf_acosh, mpf_atanh, mpf_fibonacci)
|
|
||||||
|
|
||||||
from libhyper import (NoConvergence, make_hyp_summator,
|
|
||||||
mpf_erf, mpf_erfc, mpf_ei, mpc_ei, mpf_e1, mpc_e1, mpf_expint,
|
|
||||||
mpf_ci_si, mpf_ci, mpf_si, mpc_ci, mpc_si, mpf_besseljn,
|
|
||||||
mpc_besseljn, mpf_agm, mpf_agm1, mpc_agm, mpc_agm1,
|
|
||||||
mpf_ellipk, mpc_ellipk, mpf_ellipe, mpc_ellipe)
|
|
||||||
|
|
||||||
from gammazeta import (catalan_fixed, mpf_catalan,
|
|
||||||
khinchin_fixed, mpf_khinchin, glaisher_fixed, mpf_glaisher,
|
|
||||||
apery_fixed, mpf_apery, euler_fixed, mpf_euler, mertens_fixed,
|
|
||||||
mpf_mertens, twinprime_fixed, mpf_twinprime,
|
|
||||||
mpf_bernoulli, bernfrac, mpf_gamma_int,
|
|
||||||
mpf_factorial, mpc_factorial, mpf_gamma, mpc_gamma,
|
|
||||||
mpf_harmonic, mpc_harmonic, mpf_psi0, mpc_psi0,
|
|
||||||
mpf_psi, mpc_psi, mpf_zeta_int, mpf_zeta, mpc_zeta,
|
|
||||||
mpf_altzeta, mpc_altzeta, mpf_zetasum, mpc_zetasum)
|
|
||||||
|
|
||||||
from libmpi import (mpi_str, mpi_add, mpi_sub, mpi_delta, mpi_mid,
|
|
||||||
mpi_pos, mpi_neg, mpi_abs, mpi_mul, mpi_div, mpi_exp,
|
|
||||||
mpi_log, mpi_sqrt, mpi_pow_int, mpi_pow, mpi_cos_sin,
|
|
||||||
mpi_cos, mpi_sin, mpi_tan, mpi_cot)
|
|
||||||
|
|
||||||
from libintmath import (trailing, bitcount, numeral, bin_to_radix,
|
|
||||||
isqrt, isqrt_small, isqrt_fast, sqrt_fixed, sqrtrem, ifib, ifac,
|
|
||||||
list_primes, moebius, gcd, eulernum)
|
|
||||||
|
|
||||||
from backend import (gmpy, sage, BACKEND, STRICT, MPZ, MPZ_TYPE,
|
|
||||||
MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_THREE, MPZ_FIVE, int_types)
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
# Support GMPY for high-speed large integer arithmetic. #
|
|
||||||
# #
|
|
||||||
# To allow an external module to handle arithmetic, we need to make sure #
|
|
||||||
# that all high-precision variables are declared of the correct type. MPZ #
|
|
||||||
# is the constructor for the high-precision type. It defaults to Python's #
|
|
||||||
# long type but can be assinged another type, typically gmpy.mpz. #
|
|
||||||
# #
|
|
||||||
# MPZ must be used for the mantissa component of an mpf and must be used #
|
|
||||||
# for internal fixed-point operations. #
|
|
||||||
# #
|
|
||||||
# Side-effects #
|
|
||||||
# 1) "is" cannot be used to test for special values. Must use "==". #
|
|
||||||
# 2) There are bugs in GMPY prior to v1.02 so we must use v1.03 or later. #
|
|
||||||
#----------------------------------------------------------------------------#
|
|
||||||
|
|
||||||
# So we can import it from this module
|
|
||||||
gmpy = None
|
|
||||||
sage = None
|
|
||||||
sage_utils = None
|
|
||||||
|
|
||||||
BACKEND = 'python'
|
|
||||||
MPZ = long
|
|
||||||
|
|
||||||
if 'MPMATH_NOGMPY' not in os.environ:
|
|
||||||
try:
|
|
||||||
import gmpy
|
|
||||||
if gmpy.version() >= '1.03':
|
|
||||||
BACKEND = 'gmpy'
|
|
||||||
MPZ = gmpy.mpz
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if 'MPMATH_NOSAGE' not in os.environ:
|
|
||||||
try:
|
|
||||||
import sage.all
|
|
||||||
import sage.libs.mpmath.utils as _sage_utils
|
|
||||||
sage = sage.all
|
|
||||||
sage_utils = _sage_utils
|
|
||||||
BACKEND = 'sage'
|
|
||||||
MPZ = sage.Integer
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if 'MPMATH_STRICT' in os.environ:
|
|
||||||
STRICT = True
|
|
||||||
else:
|
|
||||||
STRICT = False
|
|
||||||
|
|
||||||
MPZ_TYPE = type(MPZ(0))
|
|
||||||
MPZ_ZERO = MPZ(0)
|
|
||||||
MPZ_ONE = MPZ(1)
|
|
||||||
MPZ_TWO = MPZ(2)
|
|
||||||
MPZ_THREE = MPZ(3)
|
|
||||||
MPZ_FIVE = MPZ(5)
|
|
||||||
|
|
||||||
if BACKEND == 'python':
|
|
||||||
int_types = (int, long)
|
|
||||||
else:
|
|
||||||
int_types = (int, long, MPZ_TYPE)
|
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,461 +0,0 @@
|
||||||
"""
|
|
||||||
Utility functions for integer math.
|
|
||||||
|
|
||||||
TODO: rename, cleanup, perhaps move the gmpy wrapper code
|
|
||||||
here from settings.py
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from bisect import bisect
|
|
||||||
|
|
||||||
from backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
|
|
||||||
|
|
||||||
def giant_steps(start, target, n=2):
|
|
||||||
"""
|
|
||||||
Return a list of integers ~=
|
|
||||||
|
|
||||||
[start, n*start, ..., target/n^2, target/n, target]
|
|
||||||
|
|
||||||
but conservatively rounded so that the quotient between two
|
|
||||||
successive elements is actually slightly less than n.
|
|
||||||
|
|
||||||
With n = 2, this describes suitable precision steps for a
|
|
||||||
quadratically convergent algorithm such as Newton's method;
|
|
||||||
with n = 3 steps for cubic convergence (Halley's method), etc.
|
|
||||||
|
|
||||||
>>> giant_steps(50,1000)
|
|
||||||
[66, 128, 253, 502, 1000]
|
|
||||||
>>> giant_steps(50,1000,4)
|
|
||||||
[65, 252, 1000]
|
|
||||||
|
|
||||||
"""
|
|
||||||
L = [target]
|
|
||||||
while L[-1] > start*n:
|
|
||||||
L = L + [L[-1]//n + 2]
|
|
||||||
return L[::-1]
|
|
||||||
|
|
||||||
def rshift(x, n):
|
|
||||||
"""For an integer x, calculate x >> n with the fastest (floor)
|
|
||||||
rounding. Unlike the plain Python expression (x >> n), n is
|
|
||||||
allowed to be negative, in which case a left shift is performed."""
|
|
||||||
if n >= 0: return x >> n
|
|
||||||
else: return x << (-n)
|
|
||||||
|
|
||||||
def lshift(x, n):
|
|
||||||
"""For an integer x, calculate x << n. Unlike the plain Python
|
|
||||||
expression (x << n), n is allowed to be negative, in which case a
|
|
||||||
right shift with default (floor) rounding is performed."""
|
|
||||||
if n >= 0: return x << n
|
|
||||||
else: return x >> (-n)
|
|
||||||
|
|
||||||
if BACKEND == 'sage':
|
|
||||||
import operator
|
|
||||||
rshift = operator.rshift
|
|
||||||
lshift = operator.lshift
|
|
||||||
|
|
||||||
def python_trailing(n):
|
|
||||||
"""Count the number of trailing zero bits in abs(n)."""
|
|
||||||
if not n:
|
|
||||||
return 0
|
|
||||||
t = 0
|
|
||||||
while not n & 1:
|
|
||||||
n >>= 1
|
|
||||||
t += 1
|
|
||||||
return t
|
|
||||||
|
|
||||||
def gmpy_trailing(n):
|
|
||||||
"""Count the number of trailing zero bits in abs(n) using gmpy."""
|
|
||||||
if n: return MPZ(n).scan1()
|
|
||||||
else: return 0
|
|
||||||
|
|
||||||
# Small powers of 2
|
|
||||||
powers = [1<<_ for _ in range(300)]
|
|
||||||
|
|
||||||
def python_bitcount(n):
|
|
||||||
"""Calculate bit size of the nonnegative integer n."""
|
|
||||||
bc = bisect(powers, n)
|
|
||||||
if bc != 300:
|
|
||||||
return bc
|
|
||||||
bc = int(math.log(n, 2)) - 4
|
|
||||||
return bc + bctable[n>>bc]
|
|
||||||
|
|
||||||
def gmpy_bitcount(n):
|
|
||||||
"""Calculate bit size of the nonnegative integer n."""
|
|
||||||
if n: return MPZ(n).numdigits(2)
|
|
||||||
else: return 0
|
|
||||||
|
|
||||||
#def sage_bitcount(n):
|
|
||||||
# if n: return MPZ(n).nbits()
|
|
||||||
# else: return 0
|
|
||||||
|
|
||||||
def sage_trailing(n):
|
|
||||||
return MPZ(n).trailing_zero_bits()
|
|
||||||
|
|
||||||
if BACKEND == 'gmpy':
|
|
||||||
bitcount = gmpy_bitcount
|
|
||||||
trailing = gmpy_trailing
|
|
||||||
elif BACKEND == 'sage':
|
|
||||||
sage_bitcount = sage_utils.bitcount
|
|
||||||
bitcount = sage_bitcount
|
|
||||||
trailing = sage_trailing
|
|
||||||
else:
|
|
||||||
bitcount = python_bitcount
|
|
||||||
trailing = python_trailing
|
|
||||||
|
|
||||||
if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
|
|
||||||
bitcount = gmpy.bit_length
|
|
||||||
|
|
||||||
# Used to avoid slow function calls as far as possible
|
|
||||||
trailtable = map(trailing, range(256))
|
|
||||||
bctable = map(bitcount, range(1024))
|
|
||||||
|
|
||||||
# TODO: speed up for bases 2, 4, 8, 16, ...
|
|
||||||
|
|
||||||
def bin_to_radix(x, xbits, base, bdigits):
|
|
||||||
"""Changes radix of a fixed-point number; i.e., converts
|
|
||||||
x * 2**xbits to floor(x * 10**bdigits)."""
|
|
||||||
return x * (MPZ(base)**bdigits) >> xbits
|
|
||||||
|
|
||||||
stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
|
|
||||||
|
|
||||||
def small_numeral(n, base=10, digits=stddigits):
|
|
||||||
"""Return the string numeral of a positive integer in an arbitrary
|
|
||||||
base. Most efficient for small input."""
|
|
||||||
if base == 10:
|
|
||||||
return str(n)
|
|
||||||
digs = []
|
|
||||||
while n:
|
|
||||||
n, digit = divmod(n, base)
|
|
||||||
digs.append(digits[digit])
|
|
||||||
return "".join(digs[::-1])
|
|
||||||
|
|
||||||
def numeral_python(n, base=10, size=0, digits=stddigits):
|
|
||||||
"""Represent the integer n as a string of digits in the given base.
|
|
||||||
Recursive division is used to make this function about 3x faster
|
|
||||||
than Python's str() for converting integers to decimal strings.
|
|
||||||
|
|
||||||
The 'size' parameters specifies the number of digits in n; this
|
|
||||||
number is only used to determine splitting points and need not be
|
|
||||||
exact."""
|
|
||||||
if n <= 0:
|
|
||||||
if not n:
|
|
||||||
return "0"
|
|
||||||
return "-" + numeral(-n, base, size, digits)
|
|
||||||
# Fast enough to do directly
|
|
||||||
if size < 250:
|
|
||||||
return small_numeral(n, base, digits)
|
|
||||||
# Divide in half
|
|
||||||
half = (size // 2) + (size & 1)
|
|
||||||
A, B = divmod(n, base**half)
|
|
||||||
ad = numeral(A, base, half, digits)
|
|
||||||
bd = numeral(B, base, half, digits).rjust(half, "0")
|
|
||||||
return ad + bd
|
|
||||||
|
|
||||||
def numeral_gmpy(n, base=10, size=0, digits=stddigits):
|
|
||||||
"""Represent the integer n as a string of digits in the given base.
|
|
||||||
Recursive division is used to make this function about 3x faster
|
|
||||||
than Python's str() for converting integers to decimal strings.
|
|
||||||
|
|
||||||
The 'size' parameters specifies the number of digits in n; this
|
|
||||||
number is only used to determine splitting points and need not be
|
|
||||||
exact."""
|
|
||||||
if n < 0:
|
|
||||||
return "-" + numeral(-n, base, size, digits)
|
|
||||||
# gmpy.digits() may cause a segmentation fault when trying to convert
|
|
||||||
# extremely large values to a string. The size limit may need to be
|
|
||||||
# adjusted on some platforms, but 1500000 works on Windows and Linux.
|
|
||||||
if size < 1500000:
|
|
||||||
return gmpy.digits(n, base)
|
|
||||||
# Divide in half
|
|
||||||
half = (size // 2) + (size & 1)
|
|
||||||
A, B = divmod(n, MPZ(base)**half)
|
|
||||||
ad = numeral(A, base, half, digits)
|
|
||||||
bd = numeral(B, base, half, digits).rjust(half, "0")
|
|
||||||
return ad + bd
|
|
||||||
|
|
||||||
if BACKEND == "gmpy":
|
|
||||||
numeral = numeral_gmpy
|
|
||||||
else:
|
|
||||||
numeral = numeral_python
|
|
||||||
|
|
||||||
_1_800 = 1<<800
|
|
||||||
_1_600 = 1<<600
|
|
||||||
_1_400 = 1<<400
|
|
||||||
_1_200 = 1<<200
|
|
||||||
_1_100 = 1<<100
|
|
||||||
_1_50 = 1<<50
|
|
||||||
|
|
||||||
def isqrt_small_python(x):
|
|
||||||
"""
|
|
||||||
Correctly (floor) rounded integer square root, using
|
|
||||||
division. Fast up to ~200 digits.
|
|
||||||
"""
|
|
||||||
if not x:
|
|
||||||
return x
|
|
||||||
if x < _1_800:
|
|
||||||
# Exact with IEEE double precision arithmetic
|
|
||||||
if x < _1_50:
|
|
||||||
return int(x**0.5)
|
|
||||||
# Initial estimate can be any integer >= the true root; round up
|
|
||||||
r = int(x**0.5 * 1.00000000000001) + 1
|
|
||||||
else:
|
|
||||||
bc = bitcount(x)
|
|
||||||
n = bc//2
|
|
||||||
r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
|
|
||||||
# The following iteration now precisely computes floor(sqrt(x))
|
|
||||||
# See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
|
|
||||||
# Perspective"
|
|
||||||
while 1:
|
|
||||||
y = (r+x//r)>>1
|
|
||||||
if y >= r:
|
|
||||||
return r
|
|
||||||
r = y
|
|
||||||
|
|
||||||
def isqrt_fast_python(x):
|
|
||||||
"""
|
|
||||||
Fast approximate integer square root, computed using division-free
|
|
||||||
Newton iteration for large x. For random integers the result is almost
|
|
||||||
always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
|
|
||||||
0.1% probability. If x is very close to an exact square, the answer is
|
|
||||||
1 ulp wrong with high probability.
|
|
||||||
|
|
||||||
With 0 guard bits, the largest error over a set of 10^5 random
|
|
||||||
inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
|
|
||||||
almost certainly guarantees a max 1 ulp error.
|
|
||||||
"""
|
|
||||||
# Use direct division-based iteration if sqrt(x) < 2^400
|
|
||||||
# Assume floating-point square root accurate to within 1 ulp, then:
|
|
||||||
# 0 Newton iterations good to 52 bits
|
|
||||||
# 1 Newton iterations good to 104 bits
|
|
||||||
# 2 Newton iterations good to 208 bits
|
|
||||||
# 3 Newton iterations good to 416 bits
|
|
||||||
if x < _1_800:
|
|
||||||
y = int(x**0.5)
|
|
||||||
if x >= _1_100:
|
|
||||||
y = (y + x//y) >> 1
|
|
||||||
if x >= _1_200:
|
|
||||||
y = (y + x//y) >> 1
|
|
||||||
if x >= _1_400:
|
|
||||||
y = (y + x//y) >> 1
|
|
||||||
return y
|
|
||||||
bc = bitcount(x)
|
|
||||||
guard_bits = 10
|
|
||||||
x <<= 2*guard_bits
|
|
||||||
bc += 2*guard_bits
|
|
||||||
bc += (bc&1)
|
|
||||||
hbc = bc//2
|
|
||||||
startprec = min(50, hbc)
|
|
||||||
# Newton iteration for 1/sqrt(x), with floating-point starting value
|
|
||||||
r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
|
|
||||||
pp = startprec
|
|
||||||
for p in giant_steps(startprec, hbc):
|
|
||||||
# r**2, scaled from real size 2**(-bc) to 2**p
|
|
||||||
r2 = (r*r) >> (2*pp - p)
|
|
||||||
# x*r**2, scaled from real size ~1.0 to 2**p
|
|
||||||
xr2 = ((x >> (bc-p)) * r2) >> p
|
|
||||||
# New value of r, scaled from real size 2**(-bc/2) to 2**p
|
|
||||||
r = (r * ((3<<p) - xr2)) >> (pp+1)
|
|
||||||
pp = p
|
|
||||||
# (1/sqrt(x))*x = sqrt(x)
|
|
||||||
return (r*(x>>hbc)) >> (p+guard_bits)
|
|
||||||
|
|
||||||
def sqrtrem_python(x):
|
|
||||||
"""Correctly rounded integer (floor) square root with remainder."""
|
|
||||||
# to check cutoff:
|
|
||||||
# plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
|
|
||||||
if x < _1_600:
|
|
||||||
y = isqrt_small_python(x)
|
|
||||||
return y, x - y*y
|
|
||||||
y = isqrt_fast_python(x) + 1
|
|
||||||
rem = x - y*y
|
|
||||||
# Correct remainder
|
|
||||||
while rem < 0:
|
|
||||||
y -= 1
|
|
||||||
rem += (1+2*y)
|
|
||||||
else:
|
|
||||||
if rem:
|
|
||||||
while rem > 2*(1+y):
|
|
||||||
y += 1
|
|
||||||
rem -= (1+2*y)
|
|
||||||
return y, rem
|
|
||||||
|
|
||||||
def isqrt_python(x):
|
|
||||||
"""Integer square root with correct (floor) rounding."""
|
|
||||||
return sqrtrem_python(x)[0]
|
|
||||||
|
|
||||||
def sqrt_fixed(x, prec):
|
|
||||||
return isqrt_fast(x<<prec)
|
|
||||||
|
|
||||||
sqrt_fixed2 = sqrt_fixed
|
|
||||||
|
|
||||||
if BACKEND == 'gmpy':
|
|
||||||
isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
|
|
||||||
sqrtrem = gmpy.sqrtrem
|
|
||||||
elif BACKEND == 'sage':
|
|
||||||
isqrt_small = isqrt_fast = isqrt = lambda n: MPZ(n).isqrt()
|
|
||||||
sqrtrem = lambda n: MPZ(n).sqrtrem()
|
|
||||||
else:
|
|
||||||
isqrt_small = isqrt_small_python
|
|
||||||
isqrt_fast = isqrt_fast_python
|
|
||||||
isqrt = isqrt_python
|
|
||||||
sqrtrem = sqrtrem_python
|
|
||||||
|
|
||||||
|
|
||||||
def ifib(n, _cache={}):
|
|
||||||
"""Computes the nth Fibonacci number as an integer, for
|
|
||||||
integer n."""
|
|
||||||
if n < 0:
|
|
||||||
return (-1)**(-n+1) * ifib(-n)
|
|
||||||
if n in _cache:
|
|
||||||
return _cache[n]
|
|
||||||
m = n
|
|
||||||
# Use Dijkstra's logarithmic algorithm
|
|
||||||
# The following implementation is basically equivalent to
|
|
||||||
# http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
|
|
||||||
a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
|
|
||||||
while n:
|
|
||||||
if n & 1:
|
|
||||||
aq = a*q
|
|
||||||
a, b = b*q+aq+a*p, b*p+aq
|
|
||||||
n -= 1
|
|
||||||
else:
|
|
||||||
qq = q*q
|
|
||||||
p, q = p*p+qq, qq+2*p*q
|
|
||||||
n >>= 1
|
|
||||||
if m < 250:
|
|
||||||
_cache[m] = b
|
|
||||||
return b
|
|
||||||
|
|
||||||
MAX_FACTORIAL_CACHE = 1000
|
|
||||||
|
|
||||||
def ifac(n, memo={0:1, 1:1}):
|
|
||||||
"""Return n factorial (for integers n >= 0 only)."""
|
|
||||||
f = memo.get(n)
|
|
||||||
if f:
|
|
||||||
return f
|
|
||||||
k = len(memo)
|
|
||||||
p = memo[k-1]
|
|
||||||
MAX = MAX_FACTORIAL_CACHE
|
|
||||||
while k <= n:
|
|
||||||
p *= k
|
|
||||||
if k <= MAX:
|
|
||||||
memo[k] = p
|
|
||||||
k += 1
|
|
||||||
return p
|
|
||||||
|
|
||||||
if BACKEND == 'gmpy':
|
|
||||||
ifac = gmpy.fac
|
|
||||||
elif BACKEND == 'sage':
|
|
||||||
ifac = lambda n: int(sage.factorial(n))
|
|
||||||
ifib = sage.fibonacci
|
|
||||||
|
|
||||||
def list_primes(n):
|
|
||||||
n = n + 1
|
|
||||||
sieve = range(n)
|
|
||||||
sieve[:2] = [0, 0]
|
|
||||||
for i in xrange(2, int(n**0.5)+1):
|
|
||||||
if sieve[i]:
|
|
||||||
for j in xrange(i**2, n, i):
|
|
||||||
sieve[j] = 0
|
|
||||||
return [p for p in sieve if p]
|
|
||||||
|
|
||||||
if BACKEND == 'sage':
|
|
||||||
def list_primes(n):
|
|
||||||
return list(sage.primes(n+1))
|
|
||||||
|
|
||||||
def moebius(n):
|
|
||||||
"""
|
|
||||||
Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
|
|
||||||
is a product of `k` distinct primes and `mu(n) = 0` otherwise.
|
|
||||||
|
|
||||||
TODO: speed up using factorization
|
|
||||||
"""
|
|
||||||
n = abs(int(n))
|
|
||||||
if n < 2:
|
|
||||||
return n
|
|
||||||
factors = []
|
|
||||||
for p in xrange(2, n+1):
|
|
||||||
if not (n % p):
|
|
||||||
if not (n % p**2):
|
|
||||||
return 0
|
|
||||||
if not sum(p % f for f in factors):
|
|
||||||
factors.append(p)
|
|
||||||
return (-1)**len(factors)
|
|
||||||
|
|
||||||
def gcd(*args):
|
|
||||||
a = 0
|
|
||||||
for b in args:
|
|
||||||
if a:
|
|
||||||
while b:
|
|
||||||
a, b = b, a % b
|
|
||||||
else:
|
|
||||||
a = b
|
|
||||||
return a
|
|
||||||
|
|
||||||
|
|
||||||
# Comment by Juan Arias de Reyna:
|
|
||||||
#
|
|
||||||
# I learn this method to compute EulerE[2n] from van de Lune.
|
|
||||||
#
|
|
||||||
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
|
|
||||||
#
|
|
||||||
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies
|
|
||||||
#
|
|
||||||
# a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
|
|
||||||
#
|
|
||||||
# a(n,j) = a(n-1,j) when n+j is even
|
|
||||||
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# But we can use only one array unidimensional a(j) since to compute
|
|
||||||
# a(n,j) we only need to know a(n-1,k) where k and j are of different parity
|
|
||||||
# and we have not to conserve the used values.
|
|
||||||
#
|
|
||||||
# We cached up the values of Euler numbers to sufficiently high order.
|
|
||||||
#
|
|
||||||
# Important Observation: If we pretend to use the numbers
|
|
||||||
# EulerE[1], EulerE[2], ... , EulerE[n]
|
|
||||||
# it is convenient to compute first EulerE[n], since the algorithm
|
|
||||||
# computes first all
|
|
||||||
# the previous ones, and keeps them in the CACHE
|
|
||||||
|
|
||||||
MAX_EULER_CACHE = 500
|
|
||||||
|
|
||||||
def eulernum(m, _cache={0:MPZ_ONE}):
|
|
||||||
r"""
|
|
||||||
Computes the Euler numbers `E(n)`, which can be defined as
|
|
||||||
coefficients of the Taylor expansion of `1/cosh x`:
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> [int(eulernum(n)) for n in range(11)]
|
|
||||||
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
|
|
||||||
>>> [int(eulernum(n)) for n in range(11)] # test cache
|
|
||||||
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
|
|
||||||
|
|
||||||
"""
|
|
||||||
# for odd m > 1, the Euler numbers are zero
|
|
||||||
if m & 1:
|
|
||||||
return MPZ_ZERO
|
|
||||||
f = _cache.get(m)
|
|
||||||
if f:
|
|
||||||
return f
|
|
||||||
MAX = MAX_EULER_CACHE
|
|
||||||
n = m
|
|
||||||
a = map(MPZ, [0,0,1,0,0,0])
|
|
||||||
for n in range(1, m+1):
|
|
||||||
for j in range(n+1, -1, -2):
|
|
||||||
a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
|
|
||||||
a.append(0)
|
|
||||||
suma = 0
|
|
||||||
for k in range(n+1, -1, -2):
|
|
||||||
suma += a[k+1]
|
|
||||||
if n <= MAX:
|
|
||||||
_cache[n] = ((-1)**(n//2))*(suma // 2**n)
|
|
||||||
if n == m:
|
|
||||||
return ((-1)**(n//2))*suma // 2**n
|
|
||||||
|
|
@ -1,754 +0,0 @@
|
||||||
"""
|
|
||||||
Low-level functions for complex arithmetic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO
|
|
||||||
|
|
||||||
from libmpf import (\
|
|
||||||
round_floor, round_ceiling, round_down, round_up,
|
|
||||||
round_nearest, round_fast, bitcount,
|
|
||||||
bctable, normalize, normalize1, reciprocal_rnd, rshift, lshift, giant_steps,
|
|
||||||
negative_rnd,
|
|
||||||
to_str, to_fixed, from_man_exp, from_float, to_float, from_int, to_int,
|
|
||||||
fzero, fone, ftwo, fhalf, finf, fninf, fnan, fnone,
|
|
||||||
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul,
|
|
||||||
mpf_div, mpf_mul_int, mpf_shift, mpf_sqrt, mpf_hypot,
|
|
||||||
mpf_rdiv_int, mpf_floor, mpf_ceil,
|
|
||||||
mpf_sign,
|
|
||||||
ComplexResult
|
|
||||||
)
|
|
||||||
|
|
||||||
from libelefun import (\
|
|
||||||
mpf_pi, mpf_exp, mpf_log, mpf_cos_sin, mpf_cosh_sinh, mpf_tan, mpf_pow_int,
|
|
||||||
mpf_log_hypot,
|
|
||||||
mpf_cos_sin_pi, mpf_phi,
|
|
||||||
mpf_atan, mpf_atan2, mpf_cosh, mpf_sinh, mpf_tanh,
|
|
||||||
mpf_asin, mpf_acos, mpf_acosh, mpf_nthroot, mpf_fibonacci
|
|
||||||
)
|
|
||||||
|
|
||||||
# An mpc value is a (real, imag) tuple
|
|
||||||
mpc_one = fone, fzero
|
|
||||||
mpc_zero = fzero, fzero
|
|
||||||
mpc_two = ftwo, fzero
|
|
||||||
mpc_half = (fhalf, fzero)
|
|
||||||
|
|
||||||
_infs = (finf, fninf)
|
|
||||||
_infs_nan = (finf, fninf, fnan)
|
|
||||||
|
|
||||||
def mpc_is_inf(z):
|
|
||||||
"""Check if either real or imaginary part is infinite"""
|
|
||||||
re, im = z
|
|
||||||
if re in _infs: return True
|
|
||||||
if im in _infs: return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def mpc_is_infnan(z):
|
|
||||||
"""Check if either real or imaginary part is infinite or nan"""
|
|
||||||
re, im = z
|
|
||||||
if re in _infs_nan: return True
|
|
||||||
if im in _infs_nan: return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def mpc_to_str(z, dps, **kwargs):
|
|
||||||
re, im = z
|
|
||||||
rs = to_str(re, dps)
|
|
||||||
if im[0]:
|
|
||||||
return rs + " - " + to_str(mpf_neg(im), dps, **kwargs) + "j"
|
|
||||||
else:
|
|
||||||
return rs + " + " + to_str(im, dps, **kwargs) + "j"
|
|
||||||
|
|
||||||
def mpc_to_complex(z, strict=False):
|
|
||||||
re, im = z
|
|
||||||
return complex(to_float(re, strict), to_float(im, strict))
|
|
||||||
|
|
||||||
def mpc_hash(z):
|
|
||||||
try:
|
|
||||||
return hash(mpc_to_complex(z, strict=True))
|
|
||||||
except OverflowError:
|
|
||||||
return hash(z)
|
|
||||||
|
|
||||||
def mpc_conjugate(z, prec, rnd=round_fast):
|
|
||||||
re, im = z
|
|
||||||
return re, mpf_neg(im, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_is_nonzero(z):
|
|
||||||
return z != mpc_zero
|
|
||||||
|
|
||||||
def mpc_add(z, w, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
c, d = w
|
|
||||||
return mpf_add(a, c, prec, rnd), mpf_add(b, d, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_add_mpf(z, x, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_add(a, x, prec, rnd), b
|
|
||||||
|
|
||||||
def mpc_sub(z, w, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
c, d = w
|
|
||||||
return mpf_sub(a, c, prec, rnd), mpf_sub(b, d, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_sub_mpf(z, p, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_sub(a, p, prec, rnd), b
|
|
||||||
|
|
||||||
def mpc_pos(z, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_pos(a, prec, rnd), mpf_pos(b, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_neg(z, prec=None, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_neg(a, prec, rnd), mpf_neg(b, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_shift(z, n):
|
|
||||||
a, b = z
|
|
||||||
return mpf_shift(a, n), mpf_shift(b, n)
|
|
||||||
|
|
||||||
def mpc_abs(z, prec, rnd=round_fast):
|
|
||||||
"""Absolute value of a complex number, |a+bi|.
|
|
||||||
Returns an mpf value."""
|
|
||||||
a, b = z
|
|
||||||
return mpf_hypot(a, b, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_arg(z, prec, rnd=round_fast):
|
|
||||||
"""Argument of a complex number. Returns an mpf value."""
|
|
||||||
a, b = z
|
|
||||||
return mpf_atan2(b, a, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_floor(z, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_floor(a, prec, rnd), mpf_floor(b, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_ceil(z, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
return mpf_ceil(a, prec, rnd), mpf_ceil(b, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_mul(z, w, prec, rnd=round_fast):
|
|
||||||
"""
|
|
||||||
Complex multiplication.
|
|
||||||
|
|
||||||
Returns the real and imaginary part of (a+bi)*(c+di), rounded to
|
|
||||||
the specified precision. The rounding mode applies to the real and
|
|
||||||
imaginary parts separately.
|
|
||||||
"""
|
|
||||||
a, b = z
|
|
||||||
c, d = w
|
|
||||||
p = mpf_mul(a, c)
|
|
||||||
q = mpf_mul(b, d)
|
|
||||||
r = mpf_mul(a, d)
|
|
||||||
s = mpf_mul(b, c)
|
|
||||||
re = mpf_sub(p, q, prec, rnd)
|
|
||||||
im = mpf_add(r, s, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_square(z, prec, rnd=round_fast):
|
|
||||||
# (a+b*I)**2 == a**2 - b**2 + 2*I*a*b
|
|
||||||
a, b = z
|
|
||||||
p = mpf_mul(a,a)
|
|
||||||
q = mpf_mul(b,b)
|
|
||||||
r = mpf_mul(a,b, prec, rnd)
|
|
||||||
re = mpf_sub(p, q, prec, rnd)
|
|
||||||
im = mpf_shift(r, 1)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_mul_mpf(z, p, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
re = mpf_mul(a, p, prec, rnd)
|
|
||||||
im = mpf_mul(b, p, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_mul_imag_mpf(z, x, prec, rnd=round_fast):
|
|
||||||
"""
|
|
||||||
Multiply the mpc value z by I*x where x is an mpf value.
|
|
||||||
"""
|
|
||||||
a, b = z
|
|
||||||
re = mpf_neg(mpf_mul(b, x, prec, rnd))
|
|
||||||
im = mpf_mul(a, x, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_mul_int(z, n, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
re = mpf_mul_int(a, n, prec, rnd)
|
|
||||||
im = mpf_mul_int(b, n, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_div(z, w, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
c, d = w
|
|
||||||
wp = prec + 10
|
|
||||||
# mag = c*c + d*d
|
|
||||||
mag = mpf_add(mpf_mul(c, c), mpf_mul(d, d), wp)
|
|
||||||
# (a*c+b*d)/mag, (b*c-a*d)/mag
|
|
||||||
t = mpf_add(mpf_mul(a,c), mpf_mul(b,d), wp)
|
|
||||||
u = mpf_sub(mpf_mul(b,c), mpf_mul(a,d), wp)
|
|
||||||
return mpf_div(t,mag,prec,rnd), mpf_div(u,mag,prec,rnd)
|
|
||||||
|
|
||||||
def mpc_div_mpf(z, p, prec, rnd=round_fast):
|
|
||||||
"""Calculate z/p where p is real"""
|
|
||||||
a, b = z
|
|
||||||
re = mpf_div(a, p, prec, rnd)
|
|
||||||
im = mpf_div(b, p, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_reciprocal(z, prec, rnd=round_fast):
|
|
||||||
"""Calculate 1/z efficiently"""
|
|
||||||
a, b = z
|
|
||||||
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b),prec+10)
|
|
||||||
re = mpf_div(a, m, prec, rnd)
|
|
||||||
im = mpf_neg(mpf_div(b, m, prec, rnd))
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_mpf_div(p, z, prec, rnd=round_fast):
|
|
||||||
"""Calculate p/z where p is real efficiently"""
|
|
||||||
a, b = z
|
|
||||||
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b), prec+10)
|
|
||||||
re = mpf_div(mpf_mul(a,p), m, prec, rnd)
|
|
||||||
im = mpf_div(mpf_neg(mpf_mul(b,p)), m, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def complex_int_pow(a, b, n):
|
|
||||||
"""Complex integer power: computes (a+b*I)**n exactly for
|
|
||||||
nonnegative n (a and b must be Python ints)."""
|
|
||||||
wre = 1
|
|
||||||
wim = 0
|
|
||||||
while n:
|
|
||||||
if n & 1:
|
|
||||||
wre, wim = wre*a - wim*b, wim*a + wre*b
|
|
||||||
n -= 1
|
|
||||||
a, b = a*a - b*b, 2*a*b
|
|
||||||
n //= 2
|
|
||||||
return wre, wim
|
|
||||||
|
|
||||||
def mpc_pow(z, w, prec, rnd=round_fast):
|
|
||||||
if w[1] == fzero:
|
|
||||||
return mpc_pow_mpf(z, w[0], prec, rnd)
|
|
||||||
return mpc_exp(mpc_mul(mpc_log(z, prec+10), w, prec+10), prec, rnd)
|
|
||||||
|
|
||||||
def mpc_pow_mpf(z, p, prec, rnd=round_fast):
|
|
||||||
psign, pman, pexp, pbc = p
|
|
||||||
if pexp >= 0:
|
|
||||||
return mpc_pow_int(z, (-1)**psign * (pman<<pexp), prec, rnd)
|
|
||||||
if pexp == -1:
|
|
||||||
sqrtz = mpc_sqrt(z, prec+10)
|
|
||||||
return mpc_pow_int(sqrtz, (-1)**psign * pman, prec, rnd)
|
|
||||||
return mpc_exp(mpc_mul_mpf(mpc_log(z, prec+10), p, prec+10), prec, rnd)
|
|
||||||
|
|
||||||
def mpc_pow_int(z, n, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
if b == fzero:
|
|
||||||
return mpf_pow_int(a, n, prec, rnd), fzero
|
|
||||||
if a == fzero:
|
|
||||||
v = mpf_pow_int(b, n, prec, rnd)
|
|
||||||
n %= 4
|
|
||||||
if n == 0:
|
|
||||||
return v, fzero
|
|
||||||
elif n == 1:
|
|
||||||
return fzero, v
|
|
||||||
elif n == 2:
|
|
||||||
return mpf_neg(v), fzero
|
|
||||||
elif n == 3:
|
|
||||||
return fzero, mpf_neg(v)
|
|
||||||
if n == 0: return mpc_one
|
|
||||||
if n == 1: return mpc_pos(z, prec, rnd)
|
|
||||||
if n == 2: return mpc_square(z, prec, rnd)
|
|
||||||
if n == -1: return mpc_reciprocal(z, prec, rnd)
|
|
||||||
if n < 0: return mpc_reciprocal(mpc_pow_int(z, -n, prec+4), prec, rnd)
|
|
||||||
asign, aman, aexp, abc = a
|
|
||||||
bsign, bman, bexp, bbc = b
|
|
||||||
if asign: aman = -aman
|
|
||||||
if bsign: bman = -bman
|
|
||||||
de = aexp - bexp
|
|
||||||
abs_de = abs(de)
|
|
||||||
exact_size = n*(abs_de + max(abc, bbc))
|
|
||||||
if exact_size < 10000:
|
|
||||||
if de > 0:
|
|
||||||
aman <<= de
|
|
||||||
aexp = bexp
|
|
||||||
else:
|
|
||||||
bman <<= (-de)
|
|
||||||
bexp = aexp
|
|
||||||
re, im = complex_int_pow(aman, bman, n)
|
|
||||||
re = from_man_exp(re, int(n*aexp), prec, rnd)
|
|
||||||
im = from_man_exp(im, int(n*bexp), prec, rnd)
|
|
||||||
return re, im
|
|
||||||
return mpc_exp(mpc_mul_int(mpc_log(z, prec+10), n, prec+10), prec, rnd)
|
|
||||||
|
|
||||||
def mpc_sqrt(z, prec, rnd=round_fast):
|
|
||||||
"""Complex square root (principal branch).
|
|
||||||
|
|
||||||
We have sqrt(a+bi) = sqrt((r+a)/2) + b/sqrt(2*(r+a))*i where
|
|
||||||
r = abs(a+bi), when a+bi is not a negative real number."""
|
|
||||||
a, b = z
|
|
||||||
if b == fzero:
|
|
||||||
if a == fzero:
|
|
||||||
return (a, b)
|
|
||||||
# When a+bi is a negative real number, we get a real sqrt times i
|
|
||||||
if a[0]:
|
|
||||||
im = mpf_sqrt(mpf_neg(a), prec, rnd)
|
|
||||||
return (fzero, im)
|
|
||||||
else:
|
|
||||||
re = mpf_sqrt(a, prec, rnd)
|
|
||||||
return (re, fzero)
|
|
||||||
wp = prec+20
|
|
||||||
if not a[0]: # case a positive
|
|
||||||
t = mpf_add(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) + a
|
|
||||||
u = mpf_shift(t, -1) # u = t/2
|
|
||||||
re = mpf_sqrt(u, prec, rnd) # re = sqrt(u)
|
|
||||||
v = mpf_shift(t, 1) # v = 2*t
|
|
||||||
w = mpf_sqrt(v, wp) # w = sqrt(v)
|
|
||||||
im = mpf_div(b, w, prec, rnd) # im = b / w
|
|
||||||
else: # case a negative
|
|
||||||
t = mpf_sub(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) - a
|
|
||||||
u = mpf_shift(t, -1) # u = t/2
|
|
||||||
im = mpf_sqrt(u, prec, rnd) # im = sqrt(u)
|
|
||||||
v = mpf_shift(t, 1) # v = 2*t
|
|
||||||
w = mpf_sqrt(v, wp) # w = sqrt(v)
|
|
||||||
re = mpf_div(b, w, prec, rnd) # re = b/w
|
|
||||||
if b[0]:
|
|
||||||
re = mpf_neg(re)
|
|
||||||
im = mpf_neg(im)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_nthroot_fixed(a, b, n, prec):
|
|
||||||
# a, b signed integers at fixed precision prec
|
|
||||||
start = 50
|
|
||||||
a1 = int(rshift(a, prec - n*start))
|
|
||||||
b1 = int(rshift(b, prec - n*start))
|
|
||||||
try:
|
|
||||||
r = (a1 + 1j * b1)**(1.0/n)
|
|
||||||
re = r.real
|
|
||||||
im = r.imag
|
|
||||||
re = MPZ(int(re))
|
|
||||||
im = MPZ(int(im))
|
|
||||||
except OverflowError:
|
|
||||||
a1 = from_int(a1, start)
|
|
||||||
b1 = from_int(b1, start)
|
|
||||||
fn = from_int(n)
|
|
||||||
nth = mpf_rdiv_int(1, fn, start)
|
|
||||||
re, im = mpc_pow((a1, b1), (nth, fzero), start)
|
|
||||||
re = to_int(re)
|
|
||||||
im = to_int(im)
|
|
||||||
extra = 10
|
|
||||||
prevp = start
|
|
||||||
extra1 = n
|
|
||||||
for p in giant_steps(start, prec+extra):
|
|
||||||
# this is slow for large n, unlike int_pow_fixed
|
|
||||||
re2, im2 = complex_int_pow(re, im, n-1)
|
|
||||||
re2 = rshift(re2, (n-1)*prevp - p - extra1)
|
|
||||||
im2 = rshift(im2, (n-1)*prevp - p - extra1)
|
|
||||||
r4 = (re2*re2 + im2*im2) >> (p + extra1)
|
|
||||||
ap = rshift(a, prec - p)
|
|
||||||
bp = rshift(b, prec - p)
|
|
||||||
rec = (ap * re2 + bp * im2) >> p
|
|
||||||
imc = (-ap * im2 + bp * re2) >> p
|
|
||||||
reb = (rec << p) // r4
|
|
||||||
imb = (imc << p) // r4
|
|
||||||
re = (reb + (n-1)*lshift(re, p-prevp))//n
|
|
||||||
im = (imb + (n-1)*lshift(im, p-prevp))//n
|
|
||||||
prevp = p
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_nthroot(z, n, prec, rnd=round_fast):
|
|
||||||
"""
|
|
||||||
Complex n-th root.
|
|
||||||
|
|
||||||
Use Newton method as in the real case when it is faster,
|
|
||||||
otherwise use z**(1/n)
|
|
||||||
"""
|
|
||||||
a, b = z
|
|
||||||
if a[0] == 0 and b == fzero:
|
|
||||||
re = mpf_nthroot(a, n, prec, rnd)
|
|
||||||
return (re, fzero)
|
|
||||||
if n < 2:
|
|
||||||
if n == 0:
|
|
||||||
return mpc_one
|
|
||||||
if n == 1:
|
|
||||||
return mpc_pos((a, b), prec, rnd)
|
|
||||||
if n == -1:
|
|
||||||
return mpc_div(mpc_one, (a, b), prec, rnd)
|
|
||||||
inverse = mpc_nthroot((a, b), -n, prec+5, reciprocal_rnd[rnd])
|
|
||||||
return mpc_div(mpc_one, inverse, prec, rnd)
|
|
||||||
if n <= 20:
|
|
||||||
prec2 = int(1.2 * (prec + 10))
|
|
||||||
asign, aman, aexp, abc = a
|
|
||||||
bsign, bman, bexp, bbc = b
|
|
||||||
pf = mpc_abs((a,b), prec)
|
|
||||||
if pf[-2] + pf[-1] > -10 and pf[-2] + pf[-1] < prec:
|
|
||||||
af = to_fixed(a, prec2)
|
|
||||||
bf = to_fixed(b, prec2)
|
|
||||||
re, im = mpc_nthroot_fixed(af, bf, n, prec2)
|
|
||||||
extra = 10
|
|
||||||
re = from_man_exp(re, -prec2-extra, prec2, rnd)
|
|
||||||
im = from_man_exp(im, -prec2-extra, prec2, rnd)
|
|
||||||
return re, im
|
|
||||||
fn = from_int(n)
|
|
||||||
prec2 = prec+10 + 10
|
|
||||||
nth = mpf_rdiv_int(1, fn, prec2)
|
|
||||||
re, im = mpc_pow((a, b), (nth, fzero), prec2, rnd)
|
|
||||||
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
|
|
||||||
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_cbrt((a, b), prec, rnd=round_fast):
|
|
||||||
"""
|
|
||||||
Complex cubic root.
|
|
||||||
"""
|
|
||||||
return mpc_nthroot((a, b), 3, prec, rnd)
|
|
||||||
|
|
||||||
def mpc_exp((a, b), prec, rnd=round_fast):
|
|
||||||
"""
|
|
||||||
Complex exponential function.
|
|
||||||
|
|
||||||
We use the direct formula exp(a+bi) = exp(a) * (cos(b) + sin(b)*i)
|
|
||||||
for the computation. This formula is very nice because it is
|
|
||||||
pefectly stable; since we just do real multiplications, the only
|
|
||||||
numerical errors that can creep in are single-ulp rounding errors.
|
|
||||||
|
|
||||||
The formula is efficient since mpmath's real exp is quite fast and
|
|
||||||
since we can compute cos and sin simultaneously.
|
|
||||||
|
|
||||||
It is no problem if a and b are large; if the implementations of
|
|
||||||
exp/cos/sin are accurate and efficient for all real numbers, then
|
|
||||||
so is this function for all complex numbers.
|
|
||||||
"""
|
|
||||||
if a == fzero:
|
|
||||||
return mpf_cos_sin(b, prec, rnd)
|
|
||||||
mag = mpf_exp(a, prec+4, rnd)
|
|
||||||
c, s = mpf_cos_sin(b, prec+4, rnd)
|
|
||||||
re = mpf_mul(mag, c, prec, rnd)
|
|
||||||
im = mpf_mul(mag, s, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_log(z, prec, rnd=round_fast):
|
|
||||||
re = mpf_log_hypot(z[0], z[1], prec, rnd)
|
|
||||||
im = mpc_arg(z, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_cos((a, b), prec, rnd=round_fast):
|
|
||||||
"""Complex cosine. The formula used is cos(a+bi) = cos(a)*cosh(b) -
|
|
||||||
sin(a)*sinh(b)*i.
|
|
||||||
|
|
||||||
The same comments apply as for the complex exp: only real
|
|
||||||
multiplications are pewrormed, so no cancellation errors are
|
|
||||||
possible. The formula is also efficient since we can compute both
|
|
||||||
pairs (cos, sin) and (cosh, sinh) in single stwps."""
|
|
||||||
if a == fzero:
|
|
||||||
return mpf_cosh(b, prec, rnd), fzero
|
|
||||||
wp = prec + 6
|
|
||||||
c, s = mpf_cos_sin(a, wp)
|
|
||||||
ch, sh = mpf_cosh_sinh(b, wp)
|
|
||||||
re = mpf_mul(c, ch, prec, rnd)
|
|
||||||
im = mpf_mul(s, sh, prec, rnd)
|
|
||||||
return re, mpf_neg(im)
|
|
||||||
|
|
||||||
def mpc_sin((a, b), prec, rnd=round_fast):
|
|
||||||
"""Complex sine. We have sin(a+bi) = sin(a)*cosh(b) +
|
|
||||||
cos(a)*sinh(b)*i. See the docstring for mpc_cos for additional
|
|
||||||
comments."""
|
|
||||||
if a == fzero:
|
|
||||||
return fzero, mpf_sinh(b, prec, rnd)
|
|
||||||
wp = prec + 6
|
|
||||||
c, s = mpf_cos_sin(a, wp)
|
|
||||||
ch, sh = mpf_cosh_sinh(b, wp)
|
|
||||||
re = mpf_mul(s, ch, prec, rnd)
|
|
||||||
im = mpf_mul(c, sh, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_tan(z, prec, rnd=round_fast):
|
|
||||||
"""Complex tangent. Computed as tan(a+bi) = sin(2a)/M + sinh(2b)/M*i
|
|
||||||
where M = cos(2a) + cosh(2b)."""
|
|
||||||
a, b = z
|
|
||||||
asign, aman, aexp, abc = a
|
|
||||||
bsign, bman, bexp, bbc = b
|
|
||||||
if b == fzero: return mpf_tan(a, prec, rnd), fzero
|
|
||||||
if a == fzero: return fzero, mpf_tanh(b, prec, rnd)
|
|
||||||
wp = prec + 15
|
|
||||||
a = mpf_shift(a, 1)
|
|
||||||
b = mpf_shift(b, 1)
|
|
||||||
c, s = mpf_cos_sin(a, wp)
|
|
||||||
ch, sh = mpf_cosh_sinh(b, wp)
|
|
||||||
# TODO: handle cancellation when c ~= -1 and ch ~= 1
|
|
||||||
mag = mpf_add(c, ch, wp)
|
|
||||||
re = mpf_div(s, mag, prec, rnd)
|
|
||||||
im = mpf_div(sh, mag, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_cos_pi((a, b), prec, rnd=round_fast):
|
|
||||||
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
|
|
||||||
if a == fzero:
|
|
||||||
return mpf_cosh(b, prec, rnd), fzero
|
|
||||||
wp = prec + 6
|
|
||||||
c, s = mpf_cos_sin_pi(a, wp)
|
|
||||||
ch, sh = mpf_cosh_sinh(b, wp)
|
|
||||||
re = mpf_mul(c, ch, prec, rnd)
|
|
||||||
im = mpf_mul(s, sh, prec, rnd)
|
|
||||||
return re, mpf_neg(im)
|
|
||||||
|
|
||||||
def mpc_sin_pi((a, b), prec, rnd=round_fast):
|
|
||||||
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
|
|
||||||
if a == fzero:
|
|
||||||
return fzero, mpf_sinh(b, prec, rnd)
|
|
||||||
wp = prec + 6
|
|
||||||
c, s = mpf_cos_sin_pi(a, wp)
|
|
||||||
ch, sh = mpf_cosh_sinh(b, wp)
|
|
||||||
re = mpf_mul(s, ch, prec, rnd)
|
|
||||||
im = mpf_mul(c, sh, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_cosh((a, b), prec, rnd=round_fast):
|
|
||||||
"""Complex hyperbolic cosine. Computed as cosh(z) = cos(z*i)."""
|
|
||||||
return mpc_cos((b, mpf_neg(a)), prec, rnd)
|
|
||||||
|
|
||||||
def mpc_sinh((a, b), prec, rnd=round_fast):
|
|
||||||
"""Complex hyperbolic sine. Computed as sinh(z) = -i*sin(z*i)."""
|
|
||||||
b, a = mpc_sin((b, a), prec, rnd)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpc_tanh((a, b), prec, rnd=round_fast):
|
|
||||||
"""Complex hyperbolic tangent. Computed as tanh(z) = -i*tan(z*i)."""
|
|
||||||
b, a = mpc_tan((b, a), prec, rnd)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
# TODO: avoid loss of accuracy
|
|
||||||
def mpc_atan(z, prec, rnd=round_fast):
|
|
||||||
a, b = z
|
|
||||||
# atan(z) = (I/2)*(log(1-I*z) - log(1+I*z))
|
|
||||||
# x = 1-I*z = 1 + b - I*a
|
|
||||||
# y = 1+I*z = 1 - b + I*a
|
|
||||||
wp = prec + 15
|
|
||||||
x = mpf_add(fone, b, wp), mpf_neg(a)
|
|
||||||
y = mpf_sub(fone, b, wp), a
|
|
||||||
l1 = mpc_log(x, wp)
|
|
||||||
l2 = mpc_log(y, wp)
|
|
||||||
a, b = mpc_sub(l1, l2, prec, rnd)
|
|
||||||
# (I/2) * (a+b*I) = (-b/2 + a/2*I)
|
|
||||||
v = mpf_neg(mpf_shift(b,-1)), mpf_shift(a,-1)
|
|
||||||
# Subtraction at infinity gives correct real part but
|
|
||||||
# wrong imaginary part (should be zero)
|
|
||||||
if v[1] == fnan and mpc_is_inf(z):
|
|
||||||
v = (v[0], fzero)
|
|
||||||
return v
|
|
||||||
|
|
||||||
beta_crossover = from_float(0.6417)
|
|
||||||
alpha_crossover = from_float(1.5)
|
|
||||||
|
|
||||||
def acos_asin(z, prec, rnd, n):
|
|
||||||
""" complex acos for n = 0, asin for n = 1
|
|
||||||
The algorithm is described in
|
|
||||||
T.E. Hull, T.F. Fairgrieve and P.T.P. Tang
|
|
||||||
'Implementing the Complex Arcsine and Arcosine Functions
|
|
||||||
using Exception Handling',
|
|
||||||
ACM Trans. on Math. Software Vol. 23 (1997), p299
|
|
||||||
The complex acos and asin can be defined as
|
|
||||||
acos(z) = acos(beta) - I*sign(a)* log(alpha + sqrt(alpha**2 -1))
|
|
||||||
asin(z) = asin(beta) + I*sign(a)* log(alpha + sqrt(alpha**2 -1))
|
|
||||||
where z = a + I*b
|
|
||||||
alpha = (1/2)*(r + s); beta = (1/2)*(r - s) = a/alpha
|
|
||||||
r = sqrt((a+1)**2 + y**2); s = sqrt((a-1)**2 + y**2)
|
|
||||||
These expressions are rewritten in different ways in different
|
|
||||||
regions, delimited by two crossovers alpha_crossover and beta_crossover,
|
|
||||||
and by abs(a) <= 1, in order to improve the numerical accuracy.
|
|
||||||
"""
|
|
||||||
a, b = z
|
|
||||||
wp = prec + 10
|
|
||||||
# special cases with real argument
|
|
||||||
if b == fzero:
|
|
||||||
am = mpf_sub(fone, mpf_abs(a), wp)
|
|
||||||
# case abs(a) <= 1
|
|
||||||
if not am[0]:
|
|
||||||
if n == 0:
|
|
||||||
return mpf_acos(a, prec, rnd), fzero
|
|
||||||
else:
|
|
||||||
return mpf_asin(a, prec, rnd), fzero
|
|
||||||
# cases abs(a) > 1
|
|
||||||
else:
|
|
||||||
# case a < -1
|
|
||||||
if a[0]:
|
|
||||||
pi = mpf_pi(prec, rnd)
|
|
||||||
c = mpf_acosh(mpf_neg(a), prec, rnd)
|
|
||||||
if n == 0:
|
|
||||||
return pi, mpf_neg(c)
|
|
||||||
else:
|
|
||||||
return mpf_neg(mpf_shift(pi, -1)), c
|
|
||||||
# case a > 1
|
|
||||||
else:
|
|
||||||
c = mpf_acosh(a, prec, rnd)
|
|
||||||
if n == 0:
|
|
||||||
return fzero, c
|
|
||||||
else:
|
|
||||||
pi = mpf_pi(prec, rnd)
|
|
||||||
return mpf_shift(pi, -1), mpf_neg(c)
|
|
||||||
asign = bsign = 0
|
|
||||||
if a[0]:
|
|
||||||
a = mpf_neg(a)
|
|
||||||
asign = 1
|
|
||||||
if b[0]:
|
|
||||||
b = mpf_neg(b)
|
|
||||||
bsign = 1
|
|
||||||
am = mpf_sub(fone, a, wp)
|
|
||||||
ap = mpf_add(fone, a, wp)
|
|
||||||
r = mpf_hypot(ap, b, wp)
|
|
||||||
s = mpf_hypot(am, b, wp)
|
|
||||||
alpha = mpf_shift(mpf_add(r, s, wp), -1)
|
|
||||||
beta = mpf_div(a, alpha, wp)
|
|
||||||
b2 = mpf_mul(b,b, wp)
|
|
||||||
# case beta <= beta_crossover
|
|
||||||
if not mpf_sub(beta_crossover, beta, wp)[0]:
|
|
||||||
if n == 0:
|
|
||||||
re = mpf_acos(beta, wp)
|
|
||||||
else:
|
|
||||||
re = mpf_asin(beta, wp)
|
|
||||||
else:
|
|
||||||
# to compute the real part in this region use the identity
|
|
||||||
# asin(beta) = atan(beta/sqrt(1-beta**2))
|
|
||||||
# beta/sqrt(1-beta**2) = (alpha + a) * (alpha - a)
|
|
||||||
# alpha + a is numerically accurate; alpha - a can have
|
|
||||||
# cancellations leading to numerical inaccuracies, so rewrite
|
|
||||||
# it in differente ways according to the region
|
|
||||||
Ax = mpf_add(alpha, a, wp)
|
|
||||||
# case a <= 1
|
|
||||||
if not am[0]:
|
|
||||||
# c = b*b/(r + (a+1)); d = (s + (1-a))
|
|
||||||
# alpha - a = (1/2)*(c + d)
|
|
||||||
# case n=0: re = atan(sqrt((1/2) * Ax * (c + d))/a)
|
|
||||||
# case n=1: re = atan(a/sqrt((1/2) * Ax * (c + d)))
|
|
||||||
c = mpf_div(b2, mpf_add(r, ap, wp), wp)
|
|
||||||
d = mpf_add(s, am, wp)
|
|
||||||
re = mpf_shift(mpf_mul(Ax, mpf_add(c, d, wp), wp), -1)
|
|
||||||
if n == 0:
|
|
||||||
re = mpf_atan(mpf_div(mpf_sqrt(re, wp), a, wp), wp)
|
|
||||||
else:
|
|
||||||
re = mpf_atan(mpf_div(a, mpf_sqrt(re, wp), wp), wp)
|
|
||||||
else:
|
|
||||||
# c = Ax/(r + (a+1)); d = Ax/(s - (1-a))
|
|
||||||
# alpha - a = (1/2)*(c + d)
|
|
||||||
# case n = 0: re = atan(b*sqrt(c + d)/2/a)
|
|
||||||
# case n = 1: re = atan(a/(b*sqrt(c + d)/2)
|
|
||||||
c = mpf_div(Ax, mpf_add(r, ap, wp), wp)
|
|
||||||
d = mpf_div(Ax, mpf_sub(s, am, wp), wp)
|
|
||||||
re = mpf_shift(mpf_add(c, d, wp), -1)
|
|
||||||
re = mpf_mul(b, mpf_sqrt(re, wp), wp)
|
|
||||||
if n == 0:
|
|
||||||
re = mpf_atan(mpf_div(re, a, wp), wp)
|
|
||||||
else:
|
|
||||||
re = mpf_atan(mpf_div(a, re, wp), wp)
|
|
||||||
# to compute alpha + sqrt(alpha**2 - 1), if alpha <= alpha_crossover
|
|
||||||
# replace it with 1 + Am1 + sqrt(Am1*(alpha+1)))
|
|
||||||
# where Am1 = alpha -1
|
|
||||||
# if alpha <= alpha_crossover:
|
|
||||||
if not mpf_sub(alpha_crossover, alpha, wp)[0]:
|
|
||||||
c1 = mpf_div(b2, mpf_add(r, ap, wp), wp)
|
|
||||||
# case a < 1
|
|
||||||
if mpf_neg(am)[0]:
|
|
||||||
# Am1 = (1/2) * (b*b/(r + (a+1)) + b*b/(s + (1-a))
|
|
||||||
c2 = mpf_add(s, am, wp)
|
|
||||||
c2 = mpf_div(b2, c2, wp)
|
|
||||||
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
|
|
||||||
else:
|
|
||||||
# Am1 = (1/2) * (b*b/(r + (a+1)) + (s - (1-a)))
|
|
||||||
c2 = mpf_sub(s, am, wp)
|
|
||||||
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
|
|
||||||
# im = log(1 + Am1 + sqrt(Am1*(alpha+1)))
|
|
||||||
im = mpf_mul(Am1, mpf_add(alpha, fone, wp), wp)
|
|
||||||
im = mpf_log(mpf_add(fone, mpf_add(Am1, mpf_sqrt(im, wp), wp), wp), wp)
|
|
||||||
else:
|
|
||||||
# im = log(alpha + sqrt(alpha*alpha - 1))
|
|
||||||
im = mpf_sqrt(mpf_sub(mpf_mul(alpha, alpha, wp), fone, wp), wp)
|
|
||||||
im = mpf_log(mpf_add(alpha, im, wp), wp)
|
|
||||||
if asign:
|
|
||||||
if n == 0:
|
|
||||||
re = mpf_sub(mpf_pi(wp), re, wp)
|
|
||||||
else:
|
|
||||||
re = mpf_neg(re)
|
|
||||||
if not bsign and n == 0:
|
|
||||||
im = mpf_neg(im)
|
|
||||||
if bsign and n == 1:
|
|
||||||
im = mpf_neg(im)
|
|
||||||
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
|
|
||||||
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpc_acos(z, prec, rnd=round_fast):
|
|
||||||
return acos_asin(z, prec, rnd, 0)
|
|
||||||
|
|
||||||
def mpc_asin(z, prec, rnd=round_fast):
|
|
||||||
return acos_asin(z, prec, rnd, 1)
|
|
||||||
|
|
||||||
def mpc_asinh(z, prec, rnd=round_fast):
|
|
||||||
# asinh(z) = I * asin(-I z)
|
|
||||||
a, b = z
|
|
||||||
a, b = mpc_asin((b, mpf_neg(a)), prec, rnd)
|
|
||||||
return mpf_neg(b), a
|
|
||||||
|
|
||||||
def mpc_acosh(z, prec, rnd=round_fast):
|
|
||||||
# acosh(z) = -I * acos(z) for Im(acos(z)) <= 0
|
|
||||||
# +I * acos(z) otherwise
|
|
||||||
a, b = mpc_acos(z, prec, rnd)
|
|
||||||
if b[0] or b == fzero:
|
|
||||||
return mpf_neg(b), a
|
|
||||||
else:
|
|
||||||
return b, mpf_neg(a)
|
|
||||||
|
|
||||||
def mpc_atanh(z, prec, rnd=round_fast):
|
|
||||||
# atanh(z) = (log(1+z)-log(1-z))/2
|
|
||||||
wp = prec + 15
|
|
||||||
a = mpc_add(z, mpc_one, wp)
|
|
||||||
b = mpc_sub(mpc_one, z, wp)
|
|
||||||
a = mpc_log(a, wp)
|
|
||||||
b = mpc_log(b, wp)
|
|
||||||
v = mpc_shift(mpc_sub(a, b, wp), -1)
|
|
||||||
# Subtraction at infinity gives correct imaginary part but
|
|
||||||
# wrong real part (should be zero)
|
|
||||||
if v[0] == fnan and mpc_is_inf(z):
|
|
||||||
v = (fzero, v[1])
|
|
||||||
return v
|
|
||||||
|
|
||||||
def mpc_fibonacci(z, prec, rnd=round_fast):
|
|
||||||
re, im = z
|
|
||||||
if im == fzero:
|
|
||||||
return (mpf_fibonacci(re, prec, rnd), fzero)
|
|
||||||
size = max(abs(re[2]+re[3]), abs(re[2]+re[3]))
|
|
||||||
wp = prec + size + 20
|
|
||||||
a = mpf_phi(wp)
|
|
||||||
b = mpf_add(mpf_shift(a, 1), fnone, wp)
|
|
||||||
u = mpc_pow((a, fzero), z, wp)
|
|
||||||
v = mpc_cos_pi(z, wp)
|
|
||||||
v = mpc_div(v, u, wp)
|
|
||||||
u = mpc_sub(u, v, wp)
|
|
||||||
u = mpc_div_mpf(u, b, prec, rnd)
|
|
||||||
return u
|
|
||||||
|
|
||||||
def mpf_expj(x, prec, rnd='f'):
|
|
||||||
raise ComplexResult
|
|
||||||
|
|
||||||
def mpc_expj(z, prec, rnd='f'):
|
|
||||||
re, im = z
|
|
||||||
if im == fzero:
|
|
||||||
return mpf_cos_sin(re, prec, rnd)
|
|
||||||
if re == fzero:
|
|
||||||
return mpf_exp(mpf_neg(im), prec, rnd), fzero
|
|
||||||
ey = mpf_exp(mpf_neg(im), prec+10)
|
|
||||||
c, s = mpf_cos_sin(re, prec+10)
|
|
||||||
re = mpf_mul(ey, c, prec, rnd)
|
|
||||||
im = mpf_mul(ey, s, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
|
|
||||||
def mpf_expjpi(x, prec, rnd='f'):
|
|
||||||
raise ComplexResult
|
|
||||||
|
|
||||||
def mpc_expjpi(z, prec, rnd='f'):
|
|
||||||
re, im = z
|
|
||||||
if im == fzero:
|
|
||||||
return mpf_cos_sin_pi(re, prec, rnd)
|
|
||||||
sign, man, exp, bc = im
|
|
||||||
wp = prec+10
|
|
||||||
if man:
|
|
||||||
wp += max(0, exp+bc)
|
|
||||||
im = mpf_neg(mpf_mul(mpf_pi(wp), im, wp))
|
|
||||||
if re == fzero:
|
|
||||||
return mpf_exp(im, prec, rnd), fzero
|
|
||||||
ey = mpf_exp(im, prec+10)
|
|
||||||
c, s = mpf_cos_sin_pi(re, prec+10)
|
|
||||||
re = mpf_mul(ey, c, prec, rnd)
|
|
||||||
im = mpf_mul(ey, s, prec, rnd)
|
|
||||||
return re, im
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,348 +0,0 @@
|
||||||
"""
|
|
||||||
Computational functions for interval arithmetic.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from libmpf import (
|
|
||||||
ComplexResult,
|
|
||||||
round_down, round_up, round_floor, round_ceiling, round_nearest,
|
|
||||||
prec_to_dps, repr_dps,
|
|
||||||
fnan, finf, fninf, fzero, fhalf, fone, fnone,
|
|
||||||
mpf_sign, mpf_lt, mpf_le, mpf_gt, mpf_ge, mpf_eq, mpf_cmp,
|
|
||||||
mpf_floor, from_int, to_int, to_str,
|
|
||||||
mpf_abs, mpf_neg, mpf_pos, mpf_add, mpf_sub, mpf_mul,
|
|
||||||
mpf_div, mpf_shift, mpf_pow_int)
|
|
||||||
|
|
||||||
from libelefun import (
|
|
||||||
mpf_log, mpf_exp, mpf_sqrt, reduce_angle, calc_cos_sin
|
|
||||||
)
|
|
||||||
|
|
||||||
def mpi_str(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
dps = prec_to_dps(prec) + 5
|
|
||||||
return "[%s, %s]" % (to_str(sa, dps), to_str(sb, dps))
|
|
||||||
|
|
||||||
#dps = prec_to_dps(prec)
|
|
||||||
#m = mpi_mid(s, prec)
|
|
||||||
#d = mpf_shift(mpi_delta(s, 20), -1)
|
|
||||||
#return "%s +/- %s" % (to_str(m, dps), to_str(d, 3))
|
|
||||||
|
|
||||||
def mpi_add(s, t, prec):
|
|
||||||
sa, sb = s
|
|
||||||
ta, tb = t
|
|
||||||
a = mpf_add(sa, ta, prec, round_floor)
|
|
||||||
b = mpf_add(sb, tb, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = finf
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_sub(s, t, prec):
|
|
||||||
sa, sb = s
|
|
||||||
ta, tb = t
|
|
||||||
a = mpf_sub(sa, tb, prec, round_floor)
|
|
||||||
b = mpf_sub(sb, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = finf
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_delta(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
return mpf_sub(sb, sa, prec, round_up)
|
|
||||||
|
|
||||||
def mpi_mid(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
return mpf_shift(mpf_add(sa, sb, prec, round_nearest), -1)
|
|
||||||
|
|
||||||
def mpi_pos(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
a = mpf_pos(sa, prec, round_floor)
|
|
||||||
b = mpf_pos(sb, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_neg(s, prec=None):
|
|
||||||
sa, sb = s
|
|
||||||
a = mpf_neg(sb, prec, round_floor)
|
|
||||||
b = mpf_neg(sa, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_abs(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
sas = mpf_sign(sa)
|
|
||||||
sbs = mpf_sign(sb)
|
|
||||||
# Both points nonnegative?
|
|
||||||
if sas >= 0:
|
|
||||||
a = mpf_pos(sa, prec, round_floor)
|
|
||||||
b = mpf_pos(sb, prec, round_ceiling)
|
|
||||||
# Upper point nonnegative?
|
|
||||||
elif sbs >= 0:
|
|
||||||
a = fzero
|
|
||||||
negsa = mpf_neg(sa)
|
|
||||||
if mpf_lt(negsa, sb):
|
|
||||||
b = mpf_pos(sb, prec, round_ceiling)
|
|
||||||
else:
|
|
||||||
b = mpf_pos(negsa, prec, round_ceiling)
|
|
||||||
# Both negative?
|
|
||||||
else:
|
|
||||||
a = mpf_neg(sb, prec, round_floor)
|
|
||||||
b = mpf_neg(sa, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_mul(s, t, prec):
|
|
||||||
sa, sb = s
|
|
||||||
ta, tb = t
|
|
||||||
sas = mpf_sign(sa)
|
|
||||||
sbs = mpf_sign(sb)
|
|
||||||
tas = mpf_sign(ta)
|
|
||||||
tbs = mpf_sign(tb)
|
|
||||||
if sas == sbs == 0:
|
|
||||||
# Should maybe be undefined
|
|
||||||
if ta == fninf or tb == finf:
|
|
||||||
return fninf, finf
|
|
||||||
return fzero, fzero
|
|
||||||
if tas == tbs == 0:
|
|
||||||
# Should maybe be undefined
|
|
||||||
if sa == fninf or sb == finf:
|
|
||||||
return fninf, finf
|
|
||||||
return fzero, fzero
|
|
||||||
if sas >= 0:
|
|
||||||
# positive * positive
|
|
||||||
if tas >= 0:
|
|
||||||
a = mpf_mul(sa, ta, prec, round_floor)
|
|
||||||
b = mpf_mul(sb, tb, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fzero
|
|
||||||
if b == fnan: b = finf
|
|
||||||
# positive * negative
|
|
||||||
elif tbs <= 0:
|
|
||||||
a = mpf_mul(sb, ta, prec, round_floor)
|
|
||||||
b = mpf_mul(sa, tb, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = fzero
|
|
||||||
# positive * both signs
|
|
||||||
else:
|
|
||||||
a = mpf_mul(sb, ta, prec, round_floor)
|
|
||||||
b = mpf_mul(sb, tb, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = finf
|
|
||||||
elif sbs <= 0:
|
|
||||||
# negative * positive
|
|
||||||
if tas >= 0:
|
|
||||||
a = mpf_mul(sa, tb, prec, round_floor)
|
|
||||||
b = mpf_mul(sb, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = fzero
|
|
||||||
# negative * negative
|
|
||||||
elif tbs <= 0:
|
|
||||||
a = mpf_mul(sb, tb, prec, round_floor)
|
|
||||||
b = mpf_mul(sa, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fzero
|
|
||||||
if b == fnan: b = finf
|
|
||||||
# negative * both signs
|
|
||||||
else:
|
|
||||||
a = mpf_mul(sa, tb, prec, round_floor)
|
|
||||||
b = mpf_mul(sa, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = finf
|
|
||||||
else:
|
|
||||||
# General case: perform all cross-multiplications and compare
|
|
||||||
# Since the multiplications can be done exactly, we need only
|
|
||||||
# do 4 (instead of 8: two for each rounding mode)
|
|
||||||
cases = [mpf_mul(sa, ta), mpf_mul(sa, tb), mpf_mul(sb, ta), mpf_mul(sb, tb)]
|
|
||||||
if fnan in cases:
|
|
||||||
a, b = (fninf, finf)
|
|
||||||
else:
|
|
||||||
cases = sorted(cases, cmp=mpf_cmp)
|
|
||||||
a = mpf_pos(cases[0], prec, round_floor)
|
|
||||||
b = mpf_pos(cases[-1], prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_div(s, t, prec):
|
|
||||||
sa, sb = s
|
|
||||||
ta, tb = t
|
|
||||||
sas = mpf_sign(sa)
|
|
||||||
sbs = mpf_sign(sb)
|
|
||||||
tas = mpf_sign(ta)
|
|
||||||
tbs = mpf_sign(tb)
|
|
||||||
# 0 / X
|
|
||||||
if sas == sbs == 0:
|
|
||||||
# 0 / <interval containing 0>
|
|
||||||
if (tas < 0 and tbs > 0) or (tas == 0 or tbs == 0):
|
|
||||||
return fninf, finf
|
|
||||||
return fzero, fzero
|
|
||||||
# Denominator contains both negative and positive numbers;
|
|
||||||
# this should properly be a multi-interval, but the closest
|
|
||||||
# match is the entire (extended) real line
|
|
||||||
if tas < 0 and tbs > 0:
|
|
||||||
return fninf, finf
|
|
||||||
# Assume denominator to be nonnegative
|
|
||||||
if tas < 0:
|
|
||||||
return mpi_div(mpi_neg(s), mpi_neg(t), prec)
|
|
||||||
# Division by zero
|
|
||||||
# XXX: make sure all results make sense
|
|
||||||
if tas == 0:
|
|
||||||
# Numerator contains both signs?
|
|
||||||
if sas < 0 and sbs > 0:
|
|
||||||
return fninf, finf
|
|
||||||
if tas == tbs:
|
|
||||||
return fninf, finf
|
|
||||||
# Numerator positive?
|
|
||||||
if sas >= 0:
|
|
||||||
a = mpf_div(sa, tb, prec, round_floor)
|
|
||||||
b = finf
|
|
||||||
if sbs <= 0:
|
|
||||||
a = fninf
|
|
||||||
b = mpf_div(sb, tb, prec, round_ceiling)
|
|
||||||
# Division with positive denominator
|
|
||||||
# We still have to handle nans resulting from inf/0 or inf/inf
|
|
||||||
else:
|
|
||||||
# Nonnegative numerator
|
|
||||||
if sas >= 0:
|
|
||||||
a = mpf_div(sa, tb, prec, round_floor)
|
|
||||||
b = mpf_div(sb, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fzero
|
|
||||||
if b == fnan: b = finf
|
|
||||||
# Nonpositive numerator
|
|
||||||
elif sbs <= 0:
|
|
||||||
a = mpf_div(sa, ta, prec, round_floor)
|
|
||||||
b = mpf_div(sb, tb, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = fzero
|
|
||||||
# Numerator contains both signs?
|
|
||||||
else:
|
|
||||||
a = mpf_div(sa, ta, prec, round_floor)
|
|
||||||
b = mpf_div(sb, ta, prec, round_ceiling)
|
|
||||||
if a == fnan: a = fninf
|
|
||||||
if b == fnan: b = finf
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_exp(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
# exp is monotonous
|
|
||||||
a = mpf_exp(sa, prec, round_floor)
|
|
||||||
b = mpf_exp(sb, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_log(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
# log is monotonous
|
|
||||||
a = mpf_log(sa, prec, round_floor)
|
|
||||||
b = mpf_log(sb, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_sqrt(s, prec):
|
|
||||||
sa, sb = s
|
|
||||||
# sqrt is monotonous
|
|
||||||
a = mpf_sqrt(sa, prec, round_floor)
|
|
||||||
b = mpf_sqrt(sb, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_pow_int(s, n, prec):
|
|
||||||
sa, sb = s
|
|
||||||
if n < 0:
|
|
||||||
return mpi_div((fone, fone), mpi_pow_int(s, -n, prec+20), prec)
|
|
||||||
if n == 0:
|
|
||||||
return (fone, fone)
|
|
||||||
if n == 1:
|
|
||||||
return s
|
|
||||||
# Odd -- signs are preserved
|
|
||||||
if n & 1:
|
|
||||||
a = mpf_pow_int(sa, n, prec, round_floor)
|
|
||||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
|
||||||
# Even -- important to ensure positivity
|
|
||||||
else:
|
|
||||||
sas = mpf_sign(sa)
|
|
||||||
sbs = mpf_sign(sb)
|
|
||||||
# Nonnegative?
|
|
||||||
if sas >= 0:
|
|
||||||
a = mpf_pow_int(sa, n, prec, round_floor)
|
|
||||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
|
||||||
# Nonpositive?
|
|
||||||
elif sbs <= 0:
|
|
||||||
a = mpf_pow_int(sb, n, prec, round_floor)
|
|
||||||
b = mpf_pow_int(sa, n, prec, round_ceiling)
|
|
||||||
# Mixed signs?
|
|
||||||
else:
|
|
||||||
a = fzero
|
|
||||||
# max(-a,b)**n
|
|
||||||
sa = mpf_neg(sa)
|
|
||||||
if mpf_ge(sa, sb):
|
|
||||||
b = mpf_pow_int(sa, n, prec, round_ceiling)
|
|
||||||
else:
|
|
||||||
b = mpf_pow_int(sb, n, prec, round_ceiling)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def mpi_pow(s, t, prec):
|
|
||||||
ta, tb = t
|
|
||||||
if ta == tb and ta not in (finf, fninf):
|
|
||||||
if ta == from_int(to_int(ta)):
|
|
||||||
return mpi_pow_int(s, to_int(ta), prec)
|
|
||||||
if ta == fhalf:
|
|
||||||
return mpi_sqrt(s, prec)
|
|
||||||
u = mpi_log(s, prec + 20)
|
|
||||||
v = mpi_mul(u, t, prec + 20)
|
|
||||||
return mpi_exp(v, prec)
|
|
||||||
|
|
||||||
def MIN(x, y):
|
|
||||||
if mpf_le(x, y):
|
|
||||||
return x
|
|
||||||
return y
|
|
||||||
|
|
||||||
def MAX(x, y):
|
|
||||||
if mpf_ge(x, y):
|
|
||||||
return x
|
|
||||||
return y
|
|
||||||
|
|
||||||
def mpi_cos_sin(x, prec):
|
|
||||||
a, b = x
|
|
||||||
# Guaranteed to contain both -1 and 1
|
|
||||||
if finf in (a, b) or fninf in (a, b):
|
|
||||||
return (fnone, fone), (fnone, fone)
|
|
||||||
y, yswaps, yn = reduce_angle(a, prec+20)
|
|
||||||
z, zswaps, zn = reduce_angle(b, prec+20)
|
|
||||||
# Guaranteed to contain both -1 and 1
|
|
||||||
if zn - yn >= 4:
|
|
||||||
return (fnone, fone), (fnone, fone)
|
|
||||||
# Both points in the same quadrant -- cos and sin both strictly monotonous
|
|
||||||
if yn == zn:
|
|
||||||
m = yn % 4
|
|
||||||
if m == 0:
|
|
||||||
cb, sa = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
|
|
||||||
ca, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
|
||||||
if m == 1:
|
|
||||||
cb, sb = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_ceiling)
|
|
||||||
ca, sa = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
|
||||||
if m == 2:
|
|
||||||
ca, sb = calc_cos_sin(0, y, yswaps, prec, round_floor, round_ceiling)
|
|
||||||
cb, sa = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_floor)
|
|
||||||
if m == 3:
|
|
||||||
ca, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
|
|
||||||
cb, sb = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_ceiling)
|
|
||||||
return (ca, cb), (sa, sb)
|
|
||||||
# Intervals spanning multiple quadrants
|
|
||||||
yn %= 4
|
|
||||||
zn %= 4
|
|
||||||
case = (yn, zn)
|
|
||||||
if case == (0, 1):
|
|
||||||
cb, sy = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
|
|
||||||
ca, sz = calc_cos_sin(0, z, zswaps, prec, round_floor, round_floor)
|
|
||||||
return (ca, cb), (MIN(sy, sz), fone)
|
|
||||||
if case == (3, 0):
|
|
||||||
cy, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
|
|
||||||
cz, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
|
|
||||||
return (MIN(cy, cz), fone), (sa, sb)
|
|
||||||
|
|
||||||
|
|
||||||
raise NotImplementedError("cos/sin spanning multiple quadrants")
|
|
||||||
|
|
||||||
def mpi_cos(x, prec):
|
|
||||||
return mpi_cos_sin(x, prec)[0]
|
|
||||||
|
|
||||||
def mpi_sin(x, prec):
|
|
||||||
return mpi_cos_sin(x, prec)[1]
|
|
||||||
|
|
||||||
def mpi_tan(x, prec):
|
|
||||||
cos, sin = mpi_cos_sin(x, prec+20)
|
|
||||||
return mpi_div(sin, cos, prec)
|
|
||||||
|
|
||||||
def mpi_cot(x, prec):
|
|
||||||
cos, sin = mpi_cos_sin(x, prec+20)
|
|
||||||
return mpi_div(cos, sin, prec)
|
|
||||||
|
|
@ -1,645 +0,0 @@
|
||||||
"""
|
|
||||||
This module complements the math and cmath builtin modules by providing
|
|
||||||
fast machine precision versions of some additional functions (gamma, ...)
|
|
||||||
and wrapping math/cmath functions so that they can be called with either
|
|
||||||
real or complex arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import operator
|
|
||||||
import math
|
|
||||||
import cmath
|
|
||||||
|
|
||||||
# Irrational (?) constants
|
|
||||||
pi = 3.1415926535897932385
|
|
||||||
e = 2.7182818284590452354
|
|
||||||
sqrt2 = 1.4142135623730950488
|
|
||||||
sqrt5 = 2.2360679774997896964
|
|
||||||
phi = 1.6180339887498948482
|
|
||||||
ln2 = 0.69314718055994530942
|
|
||||||
ln10 = 2.302585092994045684
|
|
||||||
euler = 0.57721566490153286061
|
|
||||||
catalan = 0.91596559417721901505
|
|
||||||
khinchin = 2.6854520010653064453
|
|
||||||
apery = 1.2020569031595942854
|
|
||||||
|
|
||||||
logpi = 1.1447298858494001741
|
|
||||||
|
|
||||||
def _mathfun_real(f_real, f_complex):
|
|
||||||
def f(x, **kwargs):
|
|
||||||
if type(x) is float:
|
|
||||||
return f_real(x)
|
|
||||||
if type(x) is complex:
|
|
||||||
return f_complex(x)
|
|
||||||
try:
|
|
||||||
x = float(x)
|
|
||||||
return f_real(x)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
x = complex(x)
|
|
||||||
return f_complex(x)
|
|
||||||
f.__name__ = f_real.__name__
|
|
||||||
return f
|
|
||||||
|
|
||||||
def _mathfun(f_real, f_complex):
|
|
||||||
def f(x, **kwargs):
|
|
||||||
if type(x) is complex:
|
|
||||||
return f_complex(x)
|
|
||||||
try:
|
|
||||||
return f_real(float(x))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return f_complex(complex(x))
|
|
||||||
f.__name__ = f_real.__name__
|
|
||||||
return f
|
|
||||||
|
|
||||||
def _mathfun_n(f_real, f_complex):
|
|
||||||
def f(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
return f_real(*(float(x) for x in args))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return f_complex(*(complex(x) for x in args))
|
|
||||||
f.__name__ = f_real.__name__
|
|
||||||
return f
|
|
||||||
|
|
||||||
pow = _mathfun_n(operator.pow, lambda x, y: complex(x)**y)
|
|
||||||
log = _mathfun_n(math.log, cmath.log)
|
|
||||||
sqrt = _mathfun(math.sqrt, cmath.sqrt)
|
|
||||||
exp = _mathfun_real(math.exp, cmath.exp)
|
|
||||||
|
|
||||||
cos = _mathfun_real(math.cos, cmath.cos)
|
|
||||||
sin = _mathfun_real(math.sin, cmath.sin)
|
|
||||||
tan = _mathfun_real(math.tan, cmath.tan)
|
|
||||||
|
|
||||||
acos = _mathfun(math.acos, cmath.acos)
|
|
||||||
asin = _mathfun(math.asin, cmath.asin)
|
|
||||||
atan = _mathfun_real(math.atan, cmath.atan)
|
|
||||||
|
|
||||||
cosh = _mathfun_real(math.cosh, cmath.cosh)
|
|
||||||
sinh = _mathfun_real(math.sinh, cmath.sinh)
|
|
||||||
tanh = _mathfun_real(math.tanh, cmath.tanh)
|
|
||||||
|
|
||||||
floor = _mathfun_real(math.floor,
|
|
||||||
lambda z: complex(math.floor(z.real), math.floor(z.imag)))
|
|
||||||
ceil = _mathfun_real(math.ceil,
|
|
||||||
lambda z: complex(math.ceil(z.real), math.ceil(z.imag)))
|
|
||||||
|
|
||||||
|
|
||||||
cos_sin = _mathfun_real(lambda x: (math.cos(x), math.sin(x)),
|
|
||||||
lambda z: (cmath.cos(z), cmath.sin(z)))
|
|
||||||
|
|
||||||
cbrt = _mathfun(lambda x: x**(1./3), lambda z: z**(1./3))
|
|
||||||
|
|
||||||
def nthroot(x, n):
|
|
||||||
r = 1./n
|
|
||||||
try:
|
|
||||||
return float(x) ** r
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return complex(x) ** r
|
|
||||||
|
|
||||||
def _sinpi_real(x):
|
|
||||||
if x < 0:
|
|
||||||
return -_sinpi_real(-x)
|
|
||||||
n, r = divmod(x, 0.5)
|
|
||||||
r *= pi
|
|
||||||
n %= 4
|
|
||||||
if n == 0: return math.sin(r)
|
|
||||||
if n == 1: return math.cos(r)
|
|
||||||
if n == 2: return -math.sin(r)
|
|
||||||
if n == 3: return -math.cos(r)
|
|
||||||
|
|
||||||
def _cospi_real(x):
|
|
||||||
if x < 0:
|
|
||||||
x = -x
|
|
||||||
n, r = divmod(x, 0.5)
|
|
||||||
r *= pi
|
|
||||||
n %= 4
|
|
||||||
if n == 0: return math.cos(r)
|
|
||||||
if n == 1: return -math.sin(r)
|
|
||||||
if n == 2: return -math.cos(r)
|
|
||||||
if n == 3: return math.sin(r)
|
|
||||||
|
|
||||||
def _sinpi_complex(z):
|
|
||||||
if z.real < 0:
|
|
||||||
return -_sinpi_complex(-z)
|
|
||||||
n, r = divmod(z.real, 0.5)
|
|
||||||
z = pi*complex(r, z.imag)
|
|
||||||
n %= 4
|
|
||||||
if n == 0: return cmath.sin(z)
|
|
||||||
if n == 1: return cmath.cos(z)
|
|
||||||
if n == 2: return -cmath.sin(z)
|
|
||||||
if n == 3: return -cmath.cos(z)
|
|
||||||
|
|
||||||
def _cospi_complex(z):
|
|
||||||
if z.real < 0:
|
|
||||||
z = -z
|
|
||||||
n, r = divmod(z.real, 0.5)
|
|
||||||
z = pi*complex(r, z.imag)
|
|
||||||
n %= 4
|
|
||||||
if n == 0: return cmath.cos(z)
|
|
||||||
if n == 1: return -cmath.sin(z)
|
|
||||||
if n == 2: return -cmath.cos(z)
|
|
||||||
if n == 3: return cmath.sin(z)
|
|
||||||
|
|
||||||
cospi = _mathfun_real(_cospi_real, _cospi_complex)
|
|
||||||
sinpi = _mathfun_real(_sinpi_real, _sinpi_complex)
|
|
||||||
|
|
||||||
def tanpi(x):
|
|
||||||
try:
|
|
||||||
return sinpi(x) / cospi(x)
|
|
||||||
except OverflowError:
|
|
||||||
if complex(x).imag > 10:
|
|
||||||
return 1j
|
|
||||||
if complex(x).imag < 10:
|
|
||||||
return -1j
|
|
||||||
raise
|
|
||||||
|
|
||||||
def cotpi(x):
|
|
||||||
try:
|
|
||||||
return cospi(x) / sinpi(x)
|
|
||||||
except OverflowError:
|
|
||||||
if complex(x).imag > 10:
|
|
||||||
return -1j
|
|
||||||
if complex(x).imag < 10:
|
|
||||||
return 1j
|
|
||||||
raise
|
|
||||||
|
|
||||||
INF = 1e300*1e300
|
|
||||||
NINF = -INF
|
|
||||||
NAN = INF-INF
|
|
||||||
EPS = 2.2204460492503131e-16
|
|
||||||
|
|
||||||
_exact_gamma = (INF, 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0,
|
|
||||||
362880.0, 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
|
|
||||||
1307674368000.0, 20922789888000.0, 355687428096000.0, 6402373705728000.0,
|
|
||||||
121645100408832000.0, 2432902008176640000.0)
|
|
||||||
|
|
||||||
_max_exact_gamma = len(_exact_gamma)-1
|
|
||||||
|
|
||||||
# Lanczos coefficients used by the GNU Scientific Library
|
|
||||||
_lanczos_g = 7
|
|
||||||
_lanczos_p = (0.99999999999980993, 676.5203681218851, -1259.1392167224028,
|
|
||||||
771.32342877765313, -176.61502916214059, 12.507343278686905,
|
|
||||||
-0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7)
|
|
||||||
|
|
||||||
def _gamma_real(x):
|
|
||||||
_intx = int(x)
|
|
||||||
if _intx == x:
|
|
||||||
if _intx <= 0:
|
|
||||||
#return (-1)**_intx * INF
|
|
||||||
raise ZeroDivisionError("gamma function pole")
|
|
||||||
if _intx <= _max_exact_gamma:
|
|
||||||
return _exact_gamma[_intx]
|
|
||||||
if x < 0.5:
|
|
||||||
# TODO: sinpi
|
|
||||||
return pi / (_sinpi_real(x)*_gamma_real(1-x))
|
|
||||||
else:
|
|
||||||
x -= 1.0
|
|
||||||
r = _lanczos_p[0]
|
|
||||||
for i in range(1, _lanczos_g+2):
|
|
||||||
r += _lanczos_p[i]/(x+i)
|
|
||||||
t = x + _lanczos_g + 0.5
|
|
||||||
return 2.506628274631000502417 * t**(x+0.5) * math.exp(-t) * r
|
|
||||||
|
|
||||||
def _gamma_complex(x):
|
|
||||||
if not x.imag:
|
|
||||||
return complex(_gamma_real(x.real))
|
|
||||||
if x.real < 0.5:
|
|
||||||
# TODO: sinpi
|
|
||||||
return pi / (_sinpi_complex(x)*_gamma_complex(1-x))
|
|
||||||
else:
|
|
||||||
x -= 1.0
|
|
||||||
r = _lanczos_p[0]
|
|
||||||
for i in range(1, _lanczos_g+2):
|
|
||||||
r += _lanczos_p[i]/(x+i)
|
|
||||||
t = x + _lanczos_g + 0.5
|
|
||||||
return 2.506628274631000502417 * t**(x+0.5) * cmath.exp(-t) * r
|
|
||||||
|
|
||||||
gamma = _mathfun_real(_gamma_real, _gamma_complex)
|
|
||||||
|
|
||||||
def factorial(x):
|
|
||||||
return gamma(x+1.0)
|
|
||||||
|
|
||||||
def arg(x):
|
|
||||||
if type(x) is float:
|
|
||||||
return math.atan2(0.0,x)
|
|
||||||
return math.atan2(x.imag,x.real)
|
|
||||||
|
|
||||||
# XXX: broken for negatives
|
|
||||||
def loggamma(x):
|
|
||||||
if type(x) not in (float, complex):
|
|
||||||
try:
|
|
||||||
x = float(x)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
x = complex(x)
|
|
||||||
try:
|
|
||||||
xreal = x.real
|
|
||||||
ximag = x.imag
|
|
||||||
except AttributeError: # py2.5
|
|
||||||
xreal = x
|
|
||||||
ximag = 0.0
|
|
||||||
# Reflection formula
|
|
||||||
# http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0003/
|
|
||||||
if xreal < 0.0:
|
|
||||||
if abs(x) < 0.5:
|
|
||||||
v = log(gamma(x))
|
|
||||||
if ximag == 0:
|
|
||||||
v = v.conjugate()
|
|
||||||
return v
|
|
||||||
z = 1-x
|
|
||||||
try:
|
|
||||||
re = z.real
|
|
||||||
im = z.imag
|
|
||||||
except AttributeError: # py2.5
|
|
||||||
re = z
|
|
||||||
im = 0.0
|
|
||||||
refloor = floor(re)
|
|
||||||
imsign = cmp(im, 0)
|
|
||||||
return (-pi*1j)*abs(refloor)*(1-abs(imsign)) + logpi - \
|
|
||||||
log(sinpi(z-refloor)) - loggamma(z) + 1j*pi*refloor*imsign
|
|
||||||
if x == 1.0 or x == 2.0:
|
|
||||||
return x*0
|
|
||||||
p = 0.
|
|
||||||
while abs(x) < 11:
|
|
||||||
p -= log(x)
|
|
||||||
x += 1.0
|
|
||||||
s = 0.918938533204672742 + (x-0.5)*log(x) - x
|
|
||||||
r = 1./x
|
|
||||||
r2 = r*r
|
|
||||||
s += 0.083333333333333333333*r; r *= r2
|
|
||||||
s += -0.0027777777777777777778*r; r *= r2
|
|
||||||
s += 0.00079365079365079365079*r; r *= r2
|
|
||||||
s += -0.0005952380952380952381*r; r *= r2
|
|
||||||
s += 0.00084175084175084175084*r; r *= r2
|
|
||||||
s += -0.0019175269175269175269*r; r *= r2
|
|
||||||
s += 0.0064102564102564102564*r; r *= r2
|
|
||||||
s += -0.02955065359477124183*r
|
|
||||||
return s + p
|
|
||||||
|
|
||||||
_psi_coeff = [
|
|
||||||
0.083333333333333333333,
|
|
||||||
-0.0083333333333333333333,
|
|
||||||
0.003968253968253968254,
|
|
||||||
-0.0041666666666666666667,
|
|
||||||
0.0075757575757575757576,
|
|
||||||
-0.021092796092796092796,
|
|
||||||
0.083333333333333333333,
|
|
||||||
-0.44325980392156862745,
|
|
||||||
3.0539543302701197438,
|
|
||||||
-26.456212121212121212]
|
|
||||||
|
|
||||||
def _digamma_real(x):
|
|
||||||
_intx = int(x)
|
|
||||||
if _intx == x:
|
|
||||||
if _intx <= 0:
|
|
||||||
raise ZeroDivisionError("polygamma pole")
|
|
||||||
if x < 0.5:
|
|
||||||
x = 1.0-x
|
|
||||||
s = pi*cotpi(x)
|
|
||||||
else:
|
|
||||||
s = 0.0
|
|
||||||
while x < 10.0:
|
|
||||||
s -= 1.0/x
|
|
||||||
x += 1.0
|
|
||||||
x2 = x**-2
|
|
||||||
t = x2
|
|
||||||
for c in _psi_coeff:
|
|
||||||
s -= c*t
|
|
||||||
if t < 1e-20:
|
|
||||||
break
|
|
||||||
t *= x2
|
|
||||||
return s + math.log(x) - 0.5/x
|
|
||||||
|
|
||||||
def _digamma_complex(x):
|
|
||||||
if not x.imag:
|
|
||||||
return complex(_digamma_real(x.real))
|
|
||||||
if x.real < 0.5:
|
|
||||||
x = 1.0-x
|
|
||||||
s = pi*cotpi(x)
|
|
||||||
else:
|
|
||||||
s = 0.0
|
|
||||||
while abs(x) < 10.0:
|
|
||||||
s -= 1.0/x
|
|
||||||
x += 1.0
|
|
||||||
x2 = x**-2
|
|
||||||
t = x2
|
|
||||||
for c in _psi_coeff:
|
|
||||||
s -= c*t
|
|
||||||
if abs(t) < 1e-20:
|
|
||||||
break
|
|
||||||
t *= x2
|
|
||||||
return s + cmath.log(x) - 0.5/x
|
|
||||||
|
|
||||||
digamma = _mathfun_real(_digamma_real, _digamma_complex)
|
|
||||||
|
|
||||||
# TODO: could implement complex erf and erfc here. Need
|
|
||||||
# to find an accurate method (avoiding cancellation)
|
|
||||||
# for approx. 1 < abs(x) < 9.
|
|
||||||
|
|
||||||
_erfc_coeff_P = [
|
|
||||||
1.0000000161203922312,
|
|
||||||
2.1275306946297962644,
|
|
||||||
2.2280433377390253297,
|
|
||||||
1.4695509105618423961,
|
|
||||||
0.66275911699770787537,
|
|
||||||
0.20924776504163751585,
|
|
||||||
0.045459713768411264339,
|
|
||||||
0.0063065951710717791934,
|
|
||||||
0.00044560259661560421715][::-1]
|
|
||||||
|
|
||||||
_erfc_coeff_Q = [
|
|
||||||
1.0000000000000000000,
|
|
||||||
3.2559100272784894318,
|
|
||||||
4.9019435608903239131,
|
|
||||||
4.4971472894498014205,
|
|
||||||
2.7845640601891186528,
|
|
||||||
1.2146026030046904138,
|
|
||||||
0.37647108453729465912,
|
|
||||||
0.080970149639040548613,
|
|
||||||
0.011178148899483545902,
|
|
||||||
0.00078981003831980423513][::-1]
|
|
||||||
|
|
||||||
def _polyval(coeffs, x):
|
|
||||||
p = coeffs[0]
|
|
||||||
for c in coeffs[1:]:
|
|
||||||
p = c + x*p
|
|
||||||
return p
|
|
||||||
|
|
||||||
def _erf_taylor(x):
|
|
||||||
# Taylor series assuming 0 <= x <= 1
|
|
||||||
x2 = x*x
|
|
||||||
s = t = x
|
|
||||||
n = 1
|
|
||||||
while abs(t) > 1e-17:
|
|
||||||
t *= x2/n
|
|
||||||
s -= t/(n+n+1)
|
|
||||||
n += 1
|
|
||||||
t *= x2/n
|
|
||||||
s += t/(n+n+1)
|
|
||||||
n += 1
|
|
||||||
return 1.1283791670955125739*s
|
|
||||||
|
|
||||||
def _erfc_mid(x):
|
|
||||||
# Rational approximation assuming 0 <= x <= 9
|
|
||||||
return exp(-x*x)*_polyval(_erfc_coeff_P,x)/_polyval(_erfc_coeff_Q,x)
|
|
||||||
|
|
||||||
def _erfc_asymp(x):
|
|
||||||
# Asymptotic expansion assuming x >= 9
|
|
||||||
x2 = x*x
|
|
||||||
v = exp(-x2)/x*0.56418958354775628695
|
|
||||||
r = t = 0.5 / x2
|
|
||||||
s = 1.0
|
|
||||||
for n in range(1,22,4):
|
|
||||||
s -= t
|
|
||||||
t *= r * (n+2)
|
|
||||||
s += t
|
|
||||||
t *= r * (n+4)
|
|
||||||
if abs(t) < 1e-17:
|
|
||||||
break
|
|
||||||
return s * v
|
|
||||||
|
|
||||||
def erf(x):
|
|
||||||
"""
|
|
||||||
erf of a real number.
|
|
||||||
"""
|
|
||||||
x = float(x)
|
|
||||||
if x != x:
|
|
||||||
return x
|
|
||||||
if x < 0.0:
|
|
||||||
return -erf(-x)
|
|
||||||
if x >= 1.0:
|
|
||||||
if x >= 6.0:
|
|
||||||
return 1.0
|
|
||||||
return 1.0 - _erfc_mid(x)
|
|
||||||
return _erf_taylor(x)
|
|
||||||
|
|
||||||
def erfc(x):
|
|
||||||
"""
|
|
||||||
erfc of a real number.
|
|
||||||
"""
|
|
||||||
x = float(x)
|
|
||||||
if x != x:
|
|
||||||
return x
|
|
||||||
if x < 0.0:
|
|
||||||
if x < -6.0:
|
|
||||||
return 2.0
|
|
||||||
return 2.0-erfc(-x)
|
|
||||||
if x > 9.0:
|
|
||||||
return _erfc_asymp(x)
|
|
||||||
if x >= 1.0:
|
|
||||||
return _erfc_mid(x)
|
|
||||||
return 1.0 - _erf_taylor(x)
|
|
||||||
|
|
||||||
gauss42 = [\
|
|
||||||
(0.99839961899006235, 0.0041059986046490839),
|
|
||||||
(-0.99839961899006235, 0.0041059986046490839),
|
|
||||||
(0.9915772883408609, 0.009536220301748501),
|
|
||||||
(-0.9915772883408609,0.009536220301748501),
|
|
||||||
(0.97934250806374812, 0.014922443697357493),
|
|
||||||
(-0.97934250806374812, 0.014922443697357493),
|
|
||||||
(0.96175936533820439,0.020227869569052644),
|
|
||||||
(-0.96175936533820439, 0.020227869569052644),
|
|
||||||
(0.93892355735498811, 0.025422959526113047),
|
|
||||||
(-0.93892355735498811,0.025422959526113047),
|
|
||||||
(0.91095972490412735, 0.030479240699603467),
|
|
||||||
(-0.91095972490412735, 0.030479240699603467),
|
|
||||||
(0.87802056981217269,0.03536907109759211),
|
|
||||||
(-0.87802056981217269, 0.03536907109759211),
|
|
||||||
(0.8402859832618168, 0.040065735180692258),
|
|
||||||
(-0.8402859832618168,0.040065735180692258),
|
|
||||||
(0.7979620532554873, 0.044543577771965874),
|
|
||||||
(-0.7979620532554873, 0.044543577771965874),
|
|
||||||
(0.75127993568948048,0.048778140792803244),
|
|
||||||
(-0.75127993568948048, 0.048778140792803244),
|
|
||||||
(0.70049459055617114, 0.052746295699174064),
|
|
||||||
(-0.70049459055617114,0.052746295699174064),
|
|
||||||
(0.64588338886924779, 0.056426369358018376),
|
|
||||||
(-0.64588338886924779, 0.056426369358018376),
|
|
||||||
(0.58774459748510932, 0.059798262227586649),
|
|
||||||
(-0.58774459748510932, 0.059798262227586649),
|
|
||||||
(0.5263957499311922, 0.062843558045002565),
|
|
||||||
(-0.5263957499311922, 0.062843558045002565),
|
|
||||||
(0.46217191207042191, 0.065545624364908975),
|
|
||||||
(-0.46217191207042191, 0.065545624364908975),
|
|
||||||
(0.39542385204297503, 0.067889703376521934),
|
|
||||||
(-0.39542385204297503, 0.067889703376521934),
|
|
||||||
(0.32651612446541151, 0.069862992492594159),
|
|
||||||
(-0.32651612446541151, 0.069862992492594159),
|
|
||||||
(0.25582507934287907, 0.071454714265170971),
|
|
||||||
(-0.25582507934287907, 0.071454714265170971),
|
|
||||||
(0.18373680656485453, 0.072656175243804091),
|
|
||||||
(-0.18373680656485453, 0.072656175243804091),
|
|
||||||
(0.11064502720851986, 0.073460813453467527),
|
|
||||||
(-0.11064502720851986, 0.073460813453467527),
|
|
||||||
(0.036948943165351772, 0.073864234232172879),
|
|
||||||
(-0.036948943165351772, 0.073864234232172879)]
|
|
||||||
|
|
||||||
EI_ASYMP_CONVERGENCE_RADIUS = 40.0
|
|
||||||
|
|
||||||
def ei_asymp(z, _e1=False):
|
|
||||||
r = 1./z
|
|
||||||
s = t = 1.0
|
|
||||||
k = 1
|
|
||||||
while 1:
|
|
||||||
t *= k*r
|
|
||||||
s += t
|
|
||||||
if abs(t) < 1e-16:
|
|
||||||
break
|
|
||||||
k += 1
|
|
||||||
v = s*exp(z)/z
|
|
||||||
if _e1:
|
|
||||||
if type(z) is complex:
|
|
||||||
zreal = z.real
|
|
||||||
zimag = z.imag
|
|
||||||
else:
|
|
||||||
zreal = z
|
|
||||||
zimag = 0.0
|
|
||||||
if zimag == 0.0 and zreal > 0.0:
|
|
||||||
v += pi*1j
|
|
||||||
else:
|
|
||||||
if type(z) is complex:
|
|
||||||
if z.imag > 0:
|
|
||||||
v += pi*1j
|
|
||||||
if z.imag < 0:
|
|
||||||
v -= pi*1j
|
|
||||||
return v
|
|
||||||
|
|
||||||
def ei_taylor(z, _e1=False):
|
|
||||||
s = t = z
|
|
||||||
k = 2
|
|
||||||
while 1:
|
|
||||||
t = t*z/k
|
|
||||||
term = t/k
|
|
||||||
if abs(term) < 1e-17:
|
|
||||||
break
|
|
||||||
s += term
|
|
||||||
k += 1
|
|
||||||
s += euler
|
|
||||||
if _e1:
|
|
||||||
s += log(-z)
|
|
||||||
else:
|
|
||||||
if type(z) is float or z.imag == 0.0:
|
|
||||||
s += math.log(abs(z))
|
|
||||||
else:
|
|
||||||
s += cmath.log(z)
|
|
||||||
return s
|
|
||||||
|
|
||||||
def ei(z, _e1=False):
|
|
||||||
typez = type(z)
|
|
||||||
if typez not in (float, complex):
|
|
||||||
try:
|
|
||||||
z = float(z)
|
|
||||||
typez = float
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
z = complex(z)
|
|
||||||
typez = complex
|
|
||||||
if not z:
|
|
||||||
return -INF
|
|
||||||
absz = abs(z)
|
|
||||||
if absz > EI_ASYMP_CONVERGENCE_RADIUS:
|
|
||||||
return ei_asymp(z, _e1)
|
|
||||||
elif absz <= 2.0 or (typez is float and z > 0.0):
|
|
||||||
return ei_taylor(z, _e1)
|
|
||||||
# Integrate, starting from whichever is smaller of a Taylor
|
|
||||||
# series value or an asymptotic series value
|
|
||||||
if typez is complex and z.real > 0.0:
|
|
||||||
zref = z / absz
|
|
||||||
ref = ei_taylor(zref, _e1)
|
|
||||||
else:
|
|
||||||
zref = EI_ASYMP_CONVERGENCE_RADIUS * z / absz
|
|
||||||
ref = ei_asymp(zref, _e1)
|
|
||||||
C = (zref-z)*0.5
|
|
||||||
D = (zref+z)*0.5
|
|
||||||
s = 0.0
|
|
||||||
if type(z) is complex:
|
|
||||||
_exp = cmath.exp
|
|
||||||
else:
|
|
||||||
_exp = math.exp
|
|
||||||
for x,w in gauss42:
|
|
||||||
t = C*x+D
|
|
||||||
s += w*_exp(t)/t
|
|
||||||
ref -= C*s
|
|
||||||
return ref
|
|
||||||
|
|
||||||
def e1(z):
|
|
||||||
# hack to get consistent signs if the imaginary part if 0
|
|
||||||
# and signed
|
|
||||||
typez = type(z)
|
|
||||||
if type(z) not in (float, complex):
|
|
||||||
try:
|
|
||||||
z = float(z)
|
|
||||||
typez = float
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
z = complex(z)
|
|
||||||
typez = complex
|
|
||||||
if typez is complex and not z.imag:
|
|
||||||
z = complex(z.real, 0.0)
|
|
||||||
# end hack
|
|
||||||
return -ei(-z, _e1=True)
|
|
||||||
|
|
||||||
_zeta_int = [\
|
|
||||||
-0.5,
|
|
||||||
0.0,
|
|
||||||
1.6449340668482264365,1.2020569031595942854,1.0823232337111381915,
|
|
||||||
1.0369277551433699263,1.0173430619844491397,1.0083492773819228268,
|
|
||||||
1.0040773561979443394,1.0020083928260822144,1.0009945751278180853,
|
|
||||||
1.0004941886041194646,1.0002460865533080483,1.0001227133475784891,
|
|
||||||
1.0000612481350587048,1.0000305882363070205,1.0000152822594086519,
|
|
||||||
1.0000076371976378998,1.0000038172932649998,1.0000019082127165539,
|
|
||||||
1.0000009539620338728,1.0000004769329867878,1.0000002384505027277,
|
|
||||||
1.0000001192199259653,1.0000000596081890513,1.0000000298035035147,
|
|
||||||
1.0000000149015548284]
|
|
||||||
|
|
||||||
_zeta_P = [-3.50000000087575873, -0.701274355654678147,
|
|
||||||
-0.0672313458590012612, -0.00398731457954257841,
|
|
||||||
-0.000160948723019303141, -4.67633010038383371e-6,
|
|
||||||
-1.02078104417700585e-7, -1.68030037095896287e-9,
|
|
||||||
-1.85231868742346722e-11][::-1]
|
|
||||||
|
|
||||||
_zeta_Q = [1.00000000000000000, -0.936552848762465319,
|
|
||||||
-0.0588835413263763741, -0.00441498861482948666,
|
|
||||||
-0.000143416758067432622, -5.10691659585090782e-6,
|
|
||||||
-9.58813053268913799e-8, -1.72963791443181972e-9,
|
|
||||||
-1.83527919681474132e-11][::-1]
|
|
||||||
|
|
||||||
_zeta_1 = [3.03768838606128127e-10, -1.21924525236601262e-8,
|
|
||||||
2.01201845887608893e-7, -1.53917240683468381e-6,
|
|
||||||
-5.09890411005967954e-7, 0.000122464707271619326,
|
|
||||||
-0.000905721539353130232, -0.00239315326074843037,
|
|
||||||
0.084239750013159168, 0.418938517907442414, 0.500000001921884009]
|
|
||||||
|
|
||||||
_zeta_0 = [-3.46092485016748794e-10, -6.42610089468292485e-9,
|
|
||||||
1.76409071536679773e-7, -1.47141263991560698e-6, -6.38880222546167613e-7,
|
|
||||||
0.000122641099800668209, -0.000905894913516772796, -0.00239303348507992713,
|
|
||||||
0.0842396947501199816, 0.418938533204660256, 0.500000000000000052]
|
|
||||||
|
|
||||||
def zeta(s):
|
|
||||||
"""
|
|
||||||
Riemann zeta function, real argument
|
|
||||||
"""
|
|
||||||
if not isinstance(s, (float, int)):
|
|
||||||
try:
|
|
||||||
s = float(s)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
try:
|
|
||||||
s = complex(s)
|
|
||||||
if not s.imag:
|
|
||||||
return complex(zeta(s.real))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
raise NotImplementedError
|
|
||||||
if s == 1:
|
|
||||||
raise ValueError("zeta(1) pole")
|
|
||||||
if s >= 27:
|
|
||||||
return 1.0 + 2.0**(-s) + 3.0**(-s)
|
|
||||||
n = int(s)
|
|
||||||
if n == s:
|
|
||||||
if n >= 0:
|
|
||||||
return _zeta_int[n]
|
|
||||||
if not (n % 2):
|
|
||||||
return 0.0
|
|
||||||
if s <= 0.0:
|
|
||||||
return 2.**s*pi**(s-1)*_sinpi_real(0.5*s)*_gamma_real(1-s)*zeta(1-s)
|
|
||||||
if s <= 2.0:
|
|
||||||
if s <= 1.0:
|
|
||||||
return _polyval(_zeta_0,s)/(s-1)
|
|
||||||
return _polyval(_zeta_1,s)/(s-1)
|
|
||||||
z = _polyval(_zeta_P,s) / _polyval(_zeta_Q,s)
|
|
||||||
return 1.0 + 2.0**(-s) + 3.0**(-s) + 4.0**(-s)*z
|
|
||||||
|
|
@ -1,522 +0,0 @@
|
||||||
# TODO: should use diagonalization-based algorithms
|
|
||||||
|
|
||||||
class MatrixCalculusMethods:
|
|
||||||
|
|
||||||
def _exp_pade(ctx, a):
|
|
||||||
"""
|
|
||||||
Exponential of a matrix using Pade approximants.
|
|
||||||
|
|
||||||
See G. H. Golub, C. F. van Loan 'Matrix Computations',
|
|
||||||
third Ed., page 572
|
|
||||||
|
|
||||||
TODO:
|
|
||||||
- find a good estimate for q
|
|
||||||
- reduce the number of matrix multiplications to improve
|
|
||||||
performance
|
|
||||||
"""
|
|
||||||
def eps_pade(p):
|
|
||||||
return ctx.mpf(2)**(3-2*p) * \
|
|
||||||
ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
|
|
||||||
q = 4
|
|
||||||
extraq = 8
|
|
||||||
while 1:
|
|
||||||
if eps_pade(q) < ctx.eps:
|
|
||||||
break
|
|
||||||
q += 1
|
|
||||||
q += extraq
|
|
||||||
j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
|
|
||||||
extra = q
|
|
||||||
prec = ctx.prec
|
|
||||||
ctx.dps += extra + 3
|
|
||||||
try:
|
|
||||||
a = a/2**j
|
|
||||||
na = a.rows
|
|
||||||
den = ctx.eye(na)
|
|
||||||
num = ctx.eye(na)
|
|
||||||
x = ctx.eye(na)
|
|
||||||
c = ctx.mpf(1)
|
|
||||||
for k in range(1, q+1):
|
|
||||||
c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
|
|
||||||
x = a*x
|
|
||||||
cx = c*x
|
|
||||||
num += cx
|
|
||||||
den += (-1)**k * cx
|
|
||||||
f = ctx.lu_solve_mat(den, num)
|
|
||||||
for k in range(j):
|
|
||||||
f = f*f
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return f*1
|
|
||||||
|
|
||||||
def expm(ctx, A, method='taylor'):
|
|
||||||
r"""
|
|
||||||
Computes the matrix exponential of a square matrix `A`, which is defined
|
|
||||||
by the power series
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
|
|
||||||
|
|
||||||
With method='taylor', the matrix exponential is computed
|
|
||||||
using the Taylor series. With method='pade', Pade approximants
|
|
||||||
are used instead.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Basic examples::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> expm(zeros(3))
|
|
||||||
[1.0 0.0 0.0]
|
|
||||||
[0.0 1.0 0.0]
|
|
||||||
[0.0 0.0 1.0]
|
|
||||||
>>> expm(eye(3))
|
|
||||||
[2.71828182845905 0.0 0.0]
|
|
||||||
[ 0.0 2.71828182845905 0.0]
|
|
||||||
[ 0.0 0.0 2.71828182845905]
|
|
||||||
>>> expm([[1,1,0],[1,0,1],[0,1,0]])
|
|
||||||
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
|
||||||
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
|
||||||
[0.841130841230196 1.42699786729125 1.6000162976327]
|
|
||||||
>>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
|
|
||||||
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
|
||||||
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
|
||||||
[0.841130841230196 1.42699786729125 1.6000162976327]
|
|
||||||
>>> expm([[1+j, 0], [1+j,1]])
|
|
||||||
[(1.46869393991589 + 2.28735528717884j) 0.0]
|
|
||||||
[ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
|
|
||||||
|
|
||||||
Matrices with large entries are allowed::
|
|
||||||
|
|
||||||
>>> expm(matrix([[1,2],[2,3]])**25)
|
|
||||||
[5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
|
|
||||||
[9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
|
|
||||||
|
|
||||||
The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
|
|
||||||
noncommuting matrices::
|
|
||||||
|
|
||||||
>>> A = hilbert(3)
|
|
||||||
>>> B = A + eye(3)
|
|
||||||
>>> chop(mnorm(A*B - B*A))
|
|
||||||
0.0
|
|
||||||
>>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
|
|
||||||
0.0
|
|
||||||
>>> B = A + ones(3)
|
|
||||||
>>> mnorm(A*B - B*A)
|
|
||||||
1.8
|
|
||||||
>>> mnorm(expm(A+B) - expm(A)*expm(B))
|
|
||||||
42.0927851137247
|
|
||||||
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(A)
|
|
||||||
if method == 'pade':
|
|
||||||
return ctx._exp_pade(A)
|
|
||||||
prec = ctx.prec
|
|
||||||
j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
|
|
||||||
j += int(0.5*prec**0.5)
|
|
||||||
try:
|
|
||||||
ctx.prec += 10 + 2*j
|
|
||||||
tol = +ctx.eps
|
|
||||||
A = A/2**j
|
|
||||||
T = A
|
|
||||||
Y = A**0 + A
|
|
||||||
k = 2
|
|
||||||
while 1:
|
|
||||||
T *= A * (1/ctx.mpf(k))
|
|
||||||
if ctx.mnorm(T, 'inf') < tol:
|
|
||||||
break
|
|
||||||
Y += T
|
|
||||||
k += 1
|
|
||||||
for k in xrange(j):
|
|
||||||
Y = Y*Y
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
Y *= 1
|
|
||||||
return Y
|
|
||||||
|
|
||||||
def cosm(ctx, A):
|
|
||||||
r"""
|
|
||||||
Gives the cosine of a square matrix `A`, defined in analogy
|
|
||||||
with the matrix exponential.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> X = eye(3)
|
|
||||||
>>> cosm(X)
|
|
||||||
[0.54030230586814 0.0 0.0]
|
|
||||||
[ 0.0 0.54030230586814 0.0]
|
|
||||||
[ 0.0 0.0 0.54030230586814]
|
|
||||||
>>> X = hilbert(3)
|
|
||||||
>>> cosm(X)
|
|
||||||
[ 0.424403834569555 -0.316643413047167 -0.221474945949293]
|
|
||||||
[-0.316643413047167 0.820646708837824 -0.127183694770039]
|
|
||||||
[-0.221474945949293 -0.127183694770039 0.909236687217541]
|
|
||||||
>>> X = matrix([[1+j,-2],[0,-j]])
|
|
||||||
>>> cosm(X)
|
|
||||||
[(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
|
|
||||||
[ 0.0 (1.54308063481524 + 0.0j)]
|
|
||||||
"""
|
|
||||||
B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
|
|
||||||
if not sum(A.apply(ctx.im).apply(abs)):
|
|
||||||
B = B.apply(ctx.re)
|
|
||||||
return B
|
|
||||||
|
|
||||||
def sinm(ctx, A):
|
|
||||||
r"""
|
|
||||||
Gives the sine of a square matrix `A`, defined in analogy
|
|
||||||
with the matrix exponential.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> X = eye(3)
|
|
||||||
>>> sinm(X)
|
|
||||||
[0.841470984807897 0.0 0.0]
|
|
||||||
[ 0.0 0.841470984807897 0.0]
|
|
||||||
[ 0.0 0.0 0.841470984807897]
|
|
||||||
>>> X = hilbert(3)
|
|
||||||
>>> sinm(X)
|
|
||||||
[0.711608512150994 0.339783913247439 0.220742837314741]
|
|
||||||
[0.339783913247439 0.244113865695532 0.187231271174372]
|
|
||||||
[0.220742837314741 0.187231271174372 0.155816730769635]
|
|
||||||
>>> X = matrix([[1+j,-2],[0,-j]])
|
|
||||||
>>> sinm(X)
|
|
||||||
[(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
|
|
||||||
[ 0.0 (0.0 - 1.1752011936438j)]
|
|
||||||
"""
|
|
||||||
B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
|
|
||||||
if not sum(A.apply(ctx.im).apply(abs)):
|
|
||||||
B = B.apply(ctx.re)
|
|
||||||
return B
|
|
||||||
|
|
||||||
def _sqrtm_rot(ctx, A, _may_rotate):
|
|
||||||
# If the iteration fails to converge, cheat by performing
|
|
||||||
# a rotation by a complex number
|
|
||||||
u = ctx.j**0.3
|
|
||||||
return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
|
|
||||||
|
|
||||||
def sqrtm(ctx, A, _may_rotate=2):
|
|
||||||
r"""
|
|
||||||
Computes a square root of the square matrix `A`, i.e. returns
|
|
||||||
a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
|
|
||||||
of a matrix, if it exists, is not unique.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Square roots of some simple matrices::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> sqrtm([[1,0], [0,1]])
|
|
||||||
[1.0 0.0]
|
|
||||||
[0.0 1.0]
|
|
||||||
>>> sqrtm([[0,0], [0,0]])
|
|
||||||
[0.0 0.0]
|
|
||||||
[0.0 0.0]
|
|
||||||
>>> sqrtm([[2,0],[0,1]])
|
|
||||||
[1.4142135623731 0.0]
|
|
||||||
[ 0.0 1.0]
|
|
||||||
>>> sqrtm([[1,1],[1,0]])
|
|
||||||
[ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
|
|
||||||
[(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
|
|
||||||
>>> sqrtm([[1,0],[0,1]])
|
|
||||||
[1.0 0.0]
|
|
||||||
[0.0 1.0]
|
|
||||||
>>> sqrtm([[-1,0],[0,1]])
|
|
||||||
[(0.0 - 1.0j) 0.0]
|
|
||||||
[ 0.0 (1.0 + 0.0j)]
|
|
||||||
>>> sqrtm([[j,0],[0,j]])
|
|
||||||
[(0.707106781186547 + 0.707106781186547j) 0.0]
|
|
||||||
[ 0.0 (0.707106781186547 + 0.707106781186547j)]
|
|
||||||
|
|
||||||
A square root of a rotation matrix, giving the corresponding
|
|
||||||
half-angle rotation matrix::
|
|
||||||
|
|
||||||
>>> t1 = 0.75
|
|
||||||
>>> t2 = t1 * 0.5
|
|
||||||
>>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
|
|
||||||
>>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
|
|
||||||
>>> sqrtm(A1)
|
|
||||||
[0.930507621912314 -0.366272529086048]
|
|
||||||
[0.366272529086048 0.930507621912314]
|
|
||||||
>>> A2
|
|
||||||
[0.930507621912314 -0.366272529086048]
|
|
||||||
[0.366272529086048 0.930507621912314]
|
|
||||||
|
|
||||||
The identity `(A^2)^{1/2} = A` does not necessarily hold::
|
|
||||||
|
|
||||||
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
|
||||||
>>> sqrtm(A**2)
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
>>> sqrtm(A)**2
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
>>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
|
|
||||||
>>> sqrtm(A**2)
|
|
||||||
[ 7.43715112194995 -0.324127569985474 1.8481718827526]
|
|
||||||
[-0.251549715716942 9.32699765900402 2.48221180985147]
|
|
||||||
[ 4.11609388833616 0.775751877098258 13.017955697342]
|
|
||||||
>>> chop(sqrtm(A)**2)
|
|
||||||
[-4.0 1.0 4.0]
|
|
||||||
[ 7.0 -8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
|
|
||||||
For some matrices, a square root does not exist::
|
|
||||||
|
|
||||||
>>> sqrtm([[0,1], [0,0]])
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
ZeroDivisionError: matrix is numerically singular
|
|
||||||
|
|
||||||
Two examples from the documentation for Matlab's ``sqrtm``::
|
|
||||||
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> sqrtm([[7,10],[15,22]])
|
|
||||||
[1.56669890360128 1.74077655955698]
|
|
||||||
[2.61116483933547 4.17786374293675]
|
|
||||||
>>>
|
|
||||||
>>> X = matrix(\
|
|
||||||
... [[5,-4,1,0,0],
|
|
||||||
... [-4,6,-4,1,0],
|
|
||||||
... [1,-4,6,-4,1],
|
|
||||||
... [0,1,-4,6,-4],
|
|
||||||
... [0,0,1,-4,5]])
|
|
||||||
>>> Y = matrix(\
|
|
||||||
... [[2,-1,-0,-0,-0],
|
|
||||||
... [-1,2,-1,0,-0],
|
|
||||||
... [0,-1,2,-1,0],
|
|
||||||
... [-0,0,-1,2,-1],
|
|
||||||
... [-0,-0,-0,-1,2]])
|
|
||||||
>>> mnorm(sqrtm(X) - Y)
|
|
||||||
4.53155328326114e-19
|
|
||||||
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(A)
|
|
||||||
# Trivial
|
|
||||||
if A*0 == A:
|
|
||||||
return A
|
|
||||||
prec = ctx.prec
|
|
||||||
if _may_rotate:
|
|
||||||
d = ctx.det(A)
|
|
||||||
if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
|
|
||||||
return ctx._sqrtm_rot(A, _may_rotate-1)
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
tol = ctx.eps * 128
|
|
||||||
Y = A
|
|
||||||
Z = I = A**0
|
|
||||||
k = 0
|
|
||||||
# Denman-Beavers iteration
|
|
||||||
while 1:
|
|
||||||
Yprev = Y
|
|
||||||
try:
|
|
||||||
Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
|
|
||||||
except ZeroDivisionError:
|
|
||||||
if _may_rotate:
|
|
||||||
Y = ctx._sqrtm_rot(A, _may_rotate-1)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
mag1 = ctx.mnorm(Y-Yprev, 'inf')
|
|
||||||
mag2 = ctx.mnorm(Y, 'inf')
|
|
||||||
if mag1 <= mag2*tol:
|
|
||||||
break
|
|
||||||
if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
|
|
||||||
return ctx._sqrtm_rot(A, _may_rotate-1)
|
|
||||||
k += 1
|
|
||||||
if k > ctx.prec:
|
|
||||||
raise ctx.NoConvergence
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
Y *= 1
|
|
||||||
return Y
|
|
||||||
|
|
||||||
def logm(ctx, A):
|
|
||||||
r"""
|
|
||||||
Computes a logarithm of the square matrix `A`, i.e. returns
|
|
||||||
a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
|
|
||||||
of a matrix, if it exists, is not unique.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Logarithms of some simple matrices::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> X = eye(3)
|
|
||||||
>>> logm(X)
|
|
||||||
[0.0 0.0 0.0]
|
|
||||||
[0.0 0.0 0.0]
|
|
||||||
[0.0 0.0 0.0]
|
|
||||||
>>> logm(2*X)
|
|
||||||
[0.693147180559945 0.0 0.0]
|
|
||||||
[ 0.0 0.693147180559945 0.0]
|
|
||||||
[ 0.0 0.0 0.693147180559945]
|
|
||||||
>>> logm(expm(X))
|
|
||||||
[1.0 0.0 0.0]
|
|
||||||
[0.0 1.0 0.0]
|
|
||||||
[0.0 0.0 1.0]
|
|
||||||
|
|
||||||
A logarithm of a complex matrix::
|
|
||||||
|
|
||||||
>>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
|
|
||||||
>>> B = logm(X)
|
|
||||||
>>> nprint(B)
|
|
||||||
[ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
|
|
||||||
[ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
|
|
||||||
[(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
|
|
||||||
>>> chop(expm(B))
|
|
||||||
[(2.0 + 1.0j) 1.0 3.0]
|
|
||||||
[(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
|
|
||||||
[ -4.0 -5.0 (0.0 + 1.0j)]
|
|
||||||
|
|
||||||
A matrix `X` close to the identity matrix, for which
|
|
||||||
`\log(\exp(X)) = \exp(\log(X)) = X` holds::
|
|
||||||
|
|
||||||
>>> X = eye(3) + hilbert(3)/4
|
|
||||||
>>> X
|
|
||||||
[ 1.25 0.125 0.0833333333333333]
|
|
||||||
[ 0.125 1.08333333333333 0.0625]
|
|
||||||
[0.0833333333333333 0.0625 1.05]
|
|
||||||
>>> logm(expm(X))
|
|
||||||
[ 1.25 0.125 0.0833333333333333]
|
|
||||||
[ 0.125 1.08333333333333 0.0625]
|
|
||||||
[0.0833333333333333 0.0625 1.05]
|
|
||||||
>>> expm(logm(X))
|
|
||||||
[ 1.25 0.125 0.0833333333333333]
|
|
||||||
[ 0.125 1.08333333333333 0.0625]
|
|
||||||
[0.0833333333333333 0.0625 1.05]
|
|
||||||
|
|
||||||
A logarithm of a rotation matrix, giving back the angle of
|
|
||||||
the rotation::
|
|
||||||
|
|
||||||
>>> t = 3.7
|
|
||||||
>>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
|
|
||||||
>>> chop(logm(A))
|
|
||||||
[ 0.0 -2.58318530717959]
|
|
||||||
[2.58318530717959 0.0]
|
|
||||||
>>> (2*pi-t)
|
|
||||||
2.58318530717959
|
|
||||||
|
|
||||||
For some matrices, a logarithm does not exist::
|
|
||||||
|
|
||||||
>>> logm([[1,0], [0,0]])
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
ZeroDivisionError: matrix is numerically singular
|
|
||||||
|
|
||||||
Logarithm of a matrix with large entries::
|
|
||||||
|
|
||||||
>>> logm(hilbert(3) * 10**20).apply(re)
|
|
||||||
[ 45.5597513593433 1.27721006042799 0.317662687717978]
|
|
||||||
[ 1.27721006042799 42.5222778973542 2.24003708791604]
|
|
||||||
[0.317662687717978 2.24003708791604 42.395212822267]
|
|
||||||
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(A)
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
tol = ctx.eps * 128
|
|
||||||
I = A**0
|
|
||||||
B = A
|
|
||||||
n = 0
|
|
||||||
while 1:
|
|
||||||
B = ctx.sqrtm(B)
|
|
||||||
n += 1
|
|
||||||
if ctx.mnorm(B-I, 'inf') < 0.125:
|
|
||||||
break
|
|
||||||
T = X = B-I
|
|
||||||
L = X*0
|
|
||||||
k = 1
|
|
||||||
while 1:
|
|
||||||
if k & 1:
|
|
||||||
L += T / k
|
|
||||||
else:
|
|
||||||
L -= T / k
|
|
||||||
T *= X
|
|
||||||
if ctx.mnorm(T, 'inf') < tol:
|
|
||||||
break
|
|
||||||
k += 1
|
|
||||||
if k > ctx.prec:
|
|
||||||
raise ctx.NoConvergence
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
L *= 2**n
|
|
||||||
return L
|
|
||||||
|
|
||||||
def powm(ctx, A, r):
|
|
||||||
r"""
|
|
||||||
Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
|
|
||||||
number `r`.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
Powers and inverse powers of a matrix::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = True
|
|
||||||
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
|
||||||
>>> powm(A, 2)
|
|
||||||
[ 63.0 20.0 69.0]
|
|
||||||
[174.0 89.0 199.0]
|
|
||||||
[164.0 48.0 179.0]
|
|
||||||
>>> chop(powm(powm(A, 4), 1/4.))
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
>>> powm(extraprec(20)(powm)(A, -4), -1/4.)
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
>>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
>>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
|
|
||||||
[ 4.0 1.0 4.0]
|
|
||||||
[ 7.0 8.0 9.0]
|
|
||||||
[10.0 2.0 11.0]
|
|
||||||
|
|
||||||
A Fibonacci-generating matrix::
|
|
||||||
|
|
||||||
>>> powm([[1,1],[1,0]], 10)
|
|
||||||
[89.0 55.0]
|
|
||||||
[55.0 34.0]
|
|
||||||
>>> fib(10)
|
|
||||||
55.0
|
|
||||||
>>> powm([[1,1],[1,0]], 6.5)
|
|
||||||
[(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
|
|
||||||
[(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
|
|
||||||
>>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
|
|
||||||
(10.2078589271083 - 0.0195927472575932j)
|
|
||||||
>>> powm([[1,1],[1,0]], 6.2)
|
|
||||||
[ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
|
|
||||||
[(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
|
|
||||||
>>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
|
|
||||||
(8.81733464837593 - 0.0133048601383712j)
|
|
||||||
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(A)
|
|
||||||
r = ctx.convert(r)
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
if ctx.isint(r):
|
|
||||||
v = A ** int(r)
|
|
||||||
elif ctx.isint(r*2):
|
|
||||||
y = int(r*2)
|
|
||||||
v = ctx.sqrtm(A) ** y
|
|
||||||
else:
|
|
||||||
v = ctx.expm(r*ctx.logm(A))
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
v *= 1
|
|
||||||
return v
|
|
||||||
|
|
@ -1,516 +0,0 @@
|
||||||
"""
|
|
||||||
Linear algebra
|
|
||||||
--------------
|
|
||||||
|
|
||||||
Linear equations
|
|
||||||
................
|
|
||||||
|
|
||||||
Basic linear algebra is implemented; you can for example solve the linear
|
|
||||||
equation system::
|
|
||||||
|
|
||||||
x + 2*y = -10
|
|
||||||
3*x + 4*y = 10
|
|
||||||
|
|
||||||
using ``lu_solve``::
|
|
||||||
|
|
||||||
>>> A = matrix([[1, 2], [3, 4]])
|
|
||||||
>>> b = matrix([-10, 10])
|
|
||||||
>>> x = lu_solve(A, b)
|
|
||||||
>>> x
|
|
||||||
matrix(
|
|
||||||
[['30.0'],
|
|
||||||
['-20.0']])
|
|
||||||
|
|
||||||
If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
|
|
||||||
|
|
||||||
>>> residual(A, x, b)
|
|
||||||
matrix(
|
|
||||||
[['3.46944695195361e-18'],
|
|
||||||
['3.46944695195361e-18']])
|
|
||||||
>>> str(eps)
|
|
||||||
'2.22044604925031e-16'
|
|
||||||
|
|
||||||
As you can see, the solution is quite accurate. The error is caused by the
|
|
||||||
inaccuracy of the internal floating point arithmetic. Though, it's even smaller
|
|
||||||
than the current machine epsilon, which basically means you can trust the
|
|
||||||
result.
|
|
||||||
|
|
||||||
If you need more speed, use NumPy. Or choose a faster data type using the
|
|
||||||
keyword ``force_type``::
|
|
||||||
|
|
||||||
>>> lu_solve(A, b, force_type=float)
|
|
||||||
matrix(
|
|
||||||
[[29.999999999999996],
|
|
||||||
[-19.999999999999996]])
|
|
||||||
|
|
||||||
``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
|
|
||||||
such systems, so the residual is minimized instead. Internally this is done
|
|
||||||
using Cholesky decomposition to compute a least squares approximation. This means
|
|
||||||
that that ``lu_solve`` will square the errors. If you can't afford this, use
|
|
||||||
``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
|
|
||||||
the residual automatically.
|
|
||||||
|
|
||||||
|
|
||||||
Matrix factorization
|
|
||||||
....................
|
|
||||||
|
|
||||||
The function ``lu`` computes an explicit LU factorization of a matrix::
|
|
||||||
|
|
||||||
>>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
|
|
||||||
>>> print P
|
|
||||||
[0.0 0.0 1.0]
|
|
||||||
[1.0 0.0 0.0]
|
|
||||||
[0.0 1.0 0.0]
|
|
||||||
>>> print L
|
|
||||||
[ 1.0 0.0 0.0]
|
|
||||||
[ 0.0 1.0 0.0]
|
|
||||||
[0.571428571428571 0.214285714285714 1.0]
|
|
||||||
>>> print U
|
|
||||||
[7.0 8.0 9.0]
|
|
||||||
[0.0 2.0 3.0]
|
|
||||||
[0.0 0.0 0.214285714285714]
|
|
||||||
>>> print P.T*L*U
|
|
||||||
[0.0 2.0 3.0]
|
|
||||||
[4.0 5.0 6.0]
|
|
||||||
[7.0 8.0 9.0]
|
|
||||||
|
|
||||||
Interval matrices
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
Matrices may contain interval elements. This allows one to perform
|
|
||||||
basic linear algebra operations such as matrix multiplication
|
|
||||||
and equation solving with rigorous error bounds::
|
|
||||||
|
|
||||||
>>> a = matrix([['0.1','0.3','1.0'],
|
|
||||||
... ['7.1','5.5','4.8'],
|
|
||||||
... ['3.2','4.4','5.6']], force_type=mpi)
|
|
||||||
>>>
|
|
||||||
>>> b = matrix(['4','0.6','0.5'], force_type=mpi)
|
|
||||||
>>> c = lu_solve(a, b)
|
|
||||||
>>> c
|
|
||||||
matrix(
|
|
||||||
[[[5.2582327113062393041, 5.2582327113062749951]],
|
|
||||||
[[-13.155049396267856583, -13.155049396267821167]],
|
|
||||||
[[7.4206915477497212555, 7.4206915477497310922]]])
|
|
||||||
>>> print a*c
|
|
||||||
[ [3.9999999999999866773, 4.0000000000000133227]]
|
|
||||||
[[0.59999999999972430942, 0.60000000000027142733]]
|
|
||||||
[[0.49999999999982236432, 0.50000000000018474111]]
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# *implement high-level qr()
|
|
||||||
# *test unitvector
|
|
||||||
# *iterative solving
|
|
||||||
|
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
class LinearAlgebraMethods(object):
|
|
||||||
|
|
||||||
def LU_decomp(ctx, A, overwrite=False, use_cache=True):
|
|
||||||
"""
|
|
||||||
LU-factorization of a n*n matrix using the Gauss algorithm.
|
|
||||||
Returns L and U in one matrix and the pivot indices.
|
|
||||||
|
|
||||||
Use overwrite to specify whether A will be overwritten with L and U.
|
|
||||||
"""
|
|
||||||
if not A.rows == A.cols:
|
|
||||||
raise ValueError('need n*n matrix')
|
|
||||||
# get from cache if possible
|
|
||||||
if use_cache and isinstance(A, ctx.matrix) and A._LU:
|
|
||||||
return A._LU
|
|
||||||
if not overwrite:
|
|
||||||
orig = A
|
|
||||||
A = A.copy()
|
|
||||||
tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
|
|
||||||
n = A.rows
|
|
||||||
p = [None]*(n - 1)
|
|
||||||
for j in xrange(n - 1):
|
|
||||||
# pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
|
|
||||||
biggest = 0
|
|
||||||
for k in xrange(j, n):
|
|
||||||
s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
|
|
||||||
if ctx.absmin(s) <= tol:
|
|
||||||
raise ZeroDivisionError('matrix is numerically singular')
|
|
||||||
current = 1/s * ctx.absmin(A[k,j])
|
|
||||||
if current > biggest: # TODO: what if equal?
|
|
||||||
biggest = current
|
|
||||||
p[j] = k
|
|
||||||
# swap rows according to p
|
|
||||||
ctx.swap_row(A, j, p[j])
|
|
||||||
if ctx.absmin(A[j,j]) <= tol:
|
|
||||||
raise ZeroDivisionError('matrix is numerically singular')
|
|
||||||
# calculate elimination factors and add rows
|
|
||||||
for i in xrange(j + 1, n):
|
|
||||||
A[i,j] /= A[j,j]
|
|
||||||
for k in xrange(j + 1, n):
|
|
||||||
A[i,k] -= A[i,j]*A[j,k]
|
|
||||||
if ctx.absmin(A[n - 1,n - 1]) <= tol:
|
|
||||||
raise ZeroDivisionError('matrix is numerically singular')
|
|
||||||
# cache decomposition
|
|
||||||
if not overwrite and isinstance(orig, ctx.matrix):
|
|
||||||
orig._LU = (A, p)
|
|
||||||
return A, p
|
|
||||||
|
|
||||||
def L_solve(ctx, L, b, p=None):
|
|
||||||
"""
|
|
||||||
Solve the lower part of a LU factorized matrix for y.
|
|
||||||
"""
|
|
||||||
assert L.rows == L.cols, 'need n*n matrix'
|
|
||||||
n = L.rows
|
|
||||||
assert len(b) == n
|
|
||||||
b = copy(b)
|
|
||||||
if p: # swap b according to p
|
|
||||||
for k in xrange(0, len(p)):
|
|
||||||
ctx.swap_row(b, k, p[k])
|
|
||||||
# solve
|
|
||||||
for i in xrange(1, n):
|
|
||||||
for j in xrange(i):
|
|
||||||
b[i] -= L[i,j] * b[j]
|
|
||||||
return b
|
|
||||||
|
|
||||||
def U_solve(ctx, U, y):
|
|
||||||
"""
|
|
||||||
Solve the upper part of a LU factorized matrix for x.
|
|
||||||
"""
|
|
||||||
assert U.rows == U.cols, 'need n*n matrix'
|
|
||||||
n = U.rows
|
|
||||||
assert len(y) == n
|
|
||||||
x = copy(y)
|
|
||||||
for i in xrange(n - 1, -1, -1):
|
|
||||||
for j in xrange(i + 1, n):
|
|
||||||
x[i] -= U[i,j] * x[j]
|
|
||||||
x[i] /= U[i,i]
|
|
||||||
return x
|
|
||||||
|
|
||||||
def lu_solve(ctx, A, b, **kwargs):
|
|
||||||
"""
|
|
||||||
Ax = b => x
|
|
||||||
|
|
||||||
Solve a determined or overdetermined linear equations system.
|
|
||||||
Fast LU decomposition is used, which is less accurate than QR decomposition
|
|
||||||
(especially for overdetermined systems), but it's twice as efficient.
|
|
||||||
Use qr_solve if you want more precision or have to solve a very ill-
|
|
||||||
conditioned system.
|
|
||||||
|
|
||||||
If you specify real=True, it does not check for overdeterminded complex
|
|
||||||
systems.
|
|
||||||
"""
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
# do not overwrite A nor b
|
|
||||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
|
||||||
if A.rows < A.cols:
|
|
||||||
raise ValueError('cannot solve underdetermined system')
|
|
||||||
if A.rows > A.cols:
|
|
||||||
# use least-squares method if overdetermined
|
|
||||||
# (this increases errors)
|
|
||||||
AH = A.H
|
|
||||||
A = AH * A
|
|
||||||
b = AH * b
|
|
||||||
if (kwargs.get('real', False) or
|
|
||||||
not sum(type(i) is ctx.mpc for i in A)):
|
|
||||||
# TODO: necessary to check also b?
|
|
||||||
x = ctx.cholesky_solve(A, b)
|
|
||||||
else:
|
|
||||||
x = ctx.lu_solve(A, b)
|
|
||||||
else:
|
|
||||||
# LU factorization
|
|
||||||
A, p = ctx.LU_decomp(A)
|
|
||||||
b = ctx.L_solve(A, b, p)
|
|
||||||
x = ctx.U_solve(A, b)
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return x
|
|
||||||
|
|
||||||
def improve_solution(ctx, A, x, b, maxsteps=1):
|
|
||||||
"""
|
|
||||||
Improve a solution to a linear equation system iteratively.
|
|
||||||
|
|
||||||
This re-uses the LU decomposition and is thus cheap.
|
|
||||||
Usually 3 up to 4 iterations are giving the maximal improvement.
|
|
||||||
"""
|
|
||||||
assert A.rows == A.cols, 'need n*n matrix' # TODO: really?
|
|
||||||
for _ in xrange(maxsteps):
|
|
||||||
r = ctx.residual(A, x, b)
|
|
||||||
if ctx.norm(r, 2) < 10*ctx.eps:
|
|
||||||
break
|
|
||||||
# this uses cached LU decomposition and is thus cheap
|
|
||||||
dx = ctx.lu_solve(A, -r)
|
|
||||||
x += dx
|
|
||||||
return x
|
|
||||||
|
|
||||||
def lu(ctx, A):
|
|
||||||
"""
|
|
||||||
A -> P, L, U
|
|
||||||
|
|
||||||
LU factorisation of a square matrix A. L is the lower, U the upper part.
|
|
||||||
P is the permutation matrix indicating the row swaps.
|
|
||||||
|
|
||||||
P*A = L*U
|
|
||||||
|
|
||||||
If you need efficiency, use the low-level method LU_decomp instead, it's
|
|
||||||
much more memory efficient.
|
|
||||||
"""
|
|
||||||
# get factorization
|
|
||||||
A, p = ctx.LU_decomp(A)
|
|
||||||
n = A.rows
|
|
||||||
L = ctx.matrix(n)
|
|
||||||
U = ctx.matrix(n)
|
|
||||||
for i in xrange(n):
|
|
||||||
for j in xrange(n):
|
|
||||||
if i > j:
|
|
||||||
L[i,j] = A[i,j]
|
|
||||||
elif i == j:
|
|
||||||
L[i,j] = 1
|
|
||||||
U[i,j] = A[i,j]
|
|
||||||
else:
|
|
||||||
U[i,j] = A[i,j]
|
|
||||||
# calculate permutation matrix
|
|
||||||
P = ctx.eye(n)
|
|
||||||
for k in xrange(len(p)):
|
|
||||||
ctx.swap_row(P, k, p[k])
|
|
||||||
return P, L, U
|
|
||||||
|
|
||||||
def unitvector(ctx, n, i):
|
|
||||||
"""
|
|
||||||
Return the i-th n-dimensional unit vector.
|
|
||||||
"""
|
|
||||||
assert 0 < i <= n, 'this unit vector does not exist'
|
|
||||||
return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
|
|
||||||
|
|
||||||
def inverse(ctx, A, **kwargs):
|
|
||||||
"""
|
|
||||||
Calculate the inverse of a matrix.
|
|
||||||
|
|
||||||
If you want to solve an equation system Ax = b, it's recommended to use
|
|
||||||
solve(A, b) instead, it's about 3 times more efficient.
|
|
||||||
"""
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec += 10
|
|
||||||
# do not overwrite A
|
|
||||||
A = ctx.matrix(A, **kwargs).copy()
|
|
||||||
n = A.rows
|
|
||||||
# get LU factorisation
|
|
||||||
A, p = ctx.LU_decomp(A)
|
|
||||||
cols = []
|
|
||||||
# calculate unit vectors and solve corresponding system to get columns
|
|
||||||
for i in xrange(1, n + 1):
|
|
||||||
e = ctx.unitvector(n, i)
|
|
||||||
y = ctx.L_solve(A, e, p)
|
|
||||||
cols.append(ctx.U_solve(A, y))
|
|
||||||
# convert columns to matrix
|
|
||||||
inv = []
|
|
||||||
for i in xrange(n):
|
|
||||||
row = []
|
|
||||||
for j in xrange(n):
|
|
||||||
row.append(cols[j][i])
|
|
||||||
inv.append(row)
|
|
||||||
result = ctx.matrix(inv, **kwargs)
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
return result
|
|
||||||
|
|
||||||
def householder(ctx, A):
|
|
||||||
"""
|
|
||||||
(A|b) -> H, p, x, res
|
|
||||||
|
|
||||||
(A|b) is the coefficient matrix with left hand side of an optionally
|
|
||||||
overdetermined linear equation system.
|
|
||||||
H and p contain all information about the transformation matrices.
|
|
||||||
x is the solution, res the residual.
|
|
||||||
"""
|
|
||||||
assert isinstance(A, ctx.matrix)
|
|
||||||
m = A.rows
|
|
||||||
n = A.cols
|
|
||||||
assert m >= n - 1
|
|
||||||
# calculate Householder matrix
|
|
||||||
p = []
|
|
||||||
for j in xrange(0, n - 1):
|
|
||||||
s = ctx.fsum((A[i,j])**2 for i in xrange(j, m))
|
|
||||||
if not abs(s) > ctx.eps:
|
|
||||||
raise ValueError('matrix is numerically singular')
|
|
||||||
p.append(-ctx.sign(A[j,j]) * ctx.sqrt(s))
|
|
||||||
kappa = ctx.one / (s - p[j] * A[j,j])
|
|
||||||
A[j,j] -= p[j]
|
|
||||||
for k in xrange(j+1, n):
|
|
||||||
y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j, m)) * kappa
|
|
||||||
for i in xrange(j, m):
|
|
||||||
A[i,k] -= A[i,j] * y
|
|
||||||
# solve Rx = c1
|
|
||||||
x = [A[i,n - 1] for i in xrange(n - 1)]
|
|
||||||
for i in xrange(n - 2, -1, -1):
|
|
||||||
x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
|
|
||||||
x[i] /= p[i]
|
|
||||||
# calculate residual
|
|
||||||
if not m == n - 1:
|
|
||||||
r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
|
|
||||||
else:
|
|
||||||
# determined system, residual should be 0
|
|
||||||
r = [0]*m # maybe a bad idea, changing r[i] will change all elements
|
|
||||||
return A, p, x, r
|
|
||||||
|
|
||||||
#def qr(ctx, A):
|
|
||||||
# """
|
|
||||||
# A -> Q, R
|
|
||||||
#
|
|
||||||
# QR factorisation of a square matrix A using Householder decomposition.
|
|
||||||
# Q is orthogonal, this leads to very few numerical errors.
|
|
||||||
#
|
|
||||||
# A = Q*R
|
|
||||||
# """
|
|
||||||
# H, p, x, res = householder(A)
|
|
||||||
# TODO: implement this
|
|
||||||
|
|
||||||
def residual(ctx, A, x, b, **kwargs):
|
|
||||||
"""
|
|
||||||
Calculate the residual of a solution to a linear equation system.
|
|
||||||
|
|
||||||
r = A*x - b for A*x = b
|
|
||||||
"""
|
|
||||||
oldprec = ctx.prec
|
|
||||||
try:
|
|
||||||
ctx.prec *= 2
|
|
||||||
A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
|
|
||||||
return A*x - b
|
|
||||||
finally:
|
|
||||||
ctx.prec = oldprec
|
|
||||||
|
|
||||||
def qr_solve(ctx, A, b, norm=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Ax = b => x, ||Ax - b||
|
|
||||||
|
|
||||||
Solve a determined or overdetermined linear equations system and
|
|
||||||
calculate the norm of the residual (error).
|
|
||||||
QR decomposition using Householder factorization is applied, which gives very
|
|
||||||
accurate results even for ill-conditioned matrices. qr_solve is twice as
|
|
||||||
efficient.
|
|
||||||
"""
|
|
||||||
if norm is None:
|
|
||||||
norm = ctx.norm
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
prec += 10
|
|
||||||
# do not overwrite A nor b
|
|
||||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
|
||||||
if A.rows < A.cols:
|
|
||||||
raise ValueError('cannot solve underdetermined system')
|
|
||||||
H, p, x, r = ctx.householder(ctx.extend(A, b))
|
|
||||||
res = ctx.norm(r)
|
|
||||||
# calculate residual "manually" for determined systems
|
|
||||||
if res == 0:
|
|
||||||
res = ctx.norm(ctx.residual(A, x, b))
|
|
||||||
return ctx.matrix(x, **kwargs), res
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
|
|
||||||
# TODO: possible for complex matrices? -> have a look at GSL
|
|
||||||
def cholesky(ctx, A):
|
|
||||||
"""
|
|
||||||
Cholesky decomposition of a symmetric positive-definite matrix.
|
|
||||||
|
|
||||||
Can be used to solve linear equation systems twice as efficient compared
|
|
||||||
to LU decomposition or to test whether A is positive-definite.
|
|
||||||
|
|
||||||
A = L * L.T
|
|
||||||
Only L (the lower part) is returned.
|
|
||||||
"""
|
|
||||||
assert isinstance(A, ctx.matrix)
|
|
||||||
if not A.rows == A.cols:
|
|
||||||
raise ValueError('need n*n matrix')
|
|
||||||
n = A.rows
|
|
||||||
L = ctx.matrix(n)
|
|
||||||
for j in xrange(n):
|
|
||||||
s = A[j,j] - ctx.fsum(L[j,k]**2 for k in xrange(j))
|
|
||||||
if s < ctx.eps:
|
|
||||||
raise ValueError('matrix not positive-definite')
|
|
||||||
L[j,j] = ctx.sqrt(s)
|
|
||||||
for i in xrange(j, n):
|
|
||||||
L[i,j] = (A[i,j] - ctx.fsum(L[i,k] * L[j,k] for k in xrange(j))) \
|
|
||||||
/ L[j,j]
|
|
||||||
return L
|
|
||||||
|
|
||||||
def cholesky_solve(ctx, A, b, **kwargs):
|
|
||||||
"""
|
|
||||||
Ax = b => x
|
|
||||||
|
|
||||||
Solve a symmetric positive-definite linear equation system.
|
|
||||||
This is twice as efficient as lu_solve.
|
|
||||||
|
|
||||||
Typical use cases:
|
|
||||||
* A.T*A
|
|
||||||
* Hessian matrix
|
|
||||||
* differential equations
|
|
||||||
"""
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
prec += 10
|
|
||||||
# do not overwrite A nor b
|
|
||||||
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
|
||||||
if A.rows != A.cols:
|
|
||||||
raise ValueError('can only solve determined system')
|
|
||||||
# Cholesky factorization
|
|
||||||
L = ctx.cholesky(A)
|
|
||||||
# solve
|
|
||||||
n = L.rows
|
|
||||||
assert len(b) == n
|
|
||||||
for i in xrange(n):
|
|
||||||
b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
|
|
||||||
b[i] /= L[i,i]
|
|
||||||
x = ctx.U_solve(L.T, b)
|
|
||||||
return x
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
|
|
||||||
def det(ctx, A):
|
|
||||||
"""
|
|
||||||
Calculate the determinant of a matrix.
|
|
||||||
"""
|
|
||||||
prec = ctx.prec
|
|
||||||
try:
|
|
||||||
# do not overwrite A
|
|
||||||
A = ctx.matrix(A).copy()
|
|
||||||
# use LU factorization to calculate determinant
|
|
||||||
try:
|
|
||||||
R, p = ctx.LU_decomp(A)
|
|
||||||
except ZeroDivisionError:
|
|
||||||
return 0
|
|
||||||
z = 1
|
|
||||||
for i, e in enumerate(p):
|
|
||||||
if i != e:
|
|
||||||
z *= -1
|
|
||||||
for i in xrange(A.rows):
|
|
||||||
z *= R[i,i]
|
|
||||||
return z
|
|
||||||
finally:
|
|
||||||
ctx.prec = prec
|
|
||||||
|
|
||||||
def cond(ctx, A, norm=None):
|
|
||||||
"""
|
|
||||||
Calculate the condition number of a matrix using a specified matrix norm.
|
|
||||||
|
|
||||||
The condition number estimates the sensitivity of a matrix to errors.
|
|
||||||
Example: small input errors for ill-conditioned coefficient matrices
|
|
||||||
alter the solution of the system dramatically.
|
|
||||||
|
|
||||||
For ill-conditioned matrices it's recommended to use qr_solve() instead
|
|
||||||
of lu_solve(). This does not help with input errors however, it just avoids
|
|
||||||
to add additional errors.
|
|
||||||
|
|
||||||
Definition: cond(A) = ||A|| * ||A**-1||
|
|
||||||
"""
|
|
||||||
if norm is None:
|
|
||||||
norm = lambda x: ctx.mnorm(x,1)
|
|
||||||
return norm(A) * norm(ctx.inverse(A))
|
|
||||||
|
|
||||||
def lu_solve_mat(ctx, a, b):
|
|
||||||
"""Solve a * x = b where a and b are matrices."""
|
|
||||||
r = ctx.matrix(a.rows, b.cols)
|
|
||||||
for i in range(b.cols):
|
|
||||||
c = ctx.lu_solve(a, b.column(i))
|
|
||||||
for j in range(len(c)):
|
|
||||||
r[j, i] = c[j]
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
@ -1,858 +0,0 @@
|
||||||
# TODO: interpret list as vectors (for multiplication)
|
|
||||||
|
|
||||||
rowsep = '\n'
|
|
||||||
colsep = ' '
|
|
||||||
|
|
||||||
class _matrix(object):
|
|
||||||
"""
|
|
||||||
Numerical matrix.
|
|
||||||
|
|
||||||
Specify the dimensions or the data as a nested list.
|
|
||||||
Elements default to zero.
|
|
||||||
Use a flat list to create a column vector easily.
|
|
||||||
|
|
||||||
By default, only mpf is used to store the data. You can specify another type
|
|
||||||
using force_type=type. It's possible to specify None.
|
|
||||||
Make sure force_type(force_type()) is fast.
|
|
||||||
|
|
||||||
Creating matrices
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
Matrices in mpmath are implemented using dictionaries. Only non-zero values
|
|
||||||
are stored, so it is cheap to represent sparse matrices.
|
|
||||||
|
|
||||||
The most basic way to create one is to use the ``matrix`` class directly.
|
|
||||||
You can create an empty matrix specifying the dimensions:
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> matrix(2)
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
>>> matrix(2, 3)
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0', '0.0'],
|
|
||||||
['0.0', '0.0', '0.0']])
|
|
||||||
|
|
||||||
Calling ``matrix`` with one dimension will create a square matrix.
|
|
||||||
|
|
||||||
To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
|
|
||||||
|
|
||||||
>>> A = matrix(3, 2)
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
>>> A.rows
|
|
||||||
3
|
|
||||||
>>> A.cols
|
|
||||||
2
|
|
||||||
|
|
||||||
You can also change the dimension of an existing matrix. This will set the
|
|
||||||
new elements to 0. If the new dimension is smaller than before, the
|
|
||||||
concerning elements are discarded:
|
|
||||||
|
|
||||||
>>> A.rows = 2
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
|
|
||||||
Internally ``mpmathify`` is used every time an element is set. This
|
|
||||||
is done using the syntax A[row,column], counting from 0:
|
|
||||||
|
|
||||||
>>> A = matrix(2)
|
|
||||||
>>> A[1,1] = 1 + 1j
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '(1.0 + 1.0j)']])
|
|
||||||
|
|
||||||
You can use the keyword ``force_type`` to change the function which is
|
|
||||||
called on every new element:
|
|
||||||
|
|
||||||
>>> matrix(2, 5, force_type=int)
|
|
||||||
matrix(
|
|
||||||
[[0, 0, 0, 0, 0],
|
|
||||||
[0, 0, 0, 0, 0]])
|
|
||||||
|
|
||||||
A more comfortable way to create a matrix lets you use nested lists:
|
|
||||||
|
|
||||||
>>> matrix([[1, 2], [3, 4]])
|
|
||||||
matrix(
|
|
||||||
[['1.0', '2.0'],
|
|
||||||
['3.0', '4.0']])
|
|
||||||
|
|
||||||
If you want to preserve the type of the elements you can use
|
|
||||||
``force_type=None``:
|
|
||||||
|
|
||||||
>>> matrix([[1, 2.5], [1j, mpf(2)]], force_type=None)
|
|
||||||
matrix(
|
|
||||||
[[1, 2.5],
|
|
||||||
[1j, '2.0']])
|
|
||||||
|
|
||||||
Convenient advanced functions are available for creating various standard
|
|
||||||
matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
|
|
||||||
``hilbert``.
|
|
||||||
|
|
||||||
Vectors
|
|
||||||
.......
|
|
||||||
|
|
||||||
Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
|
|
||||||
For vectors there are some things which make life easier. A column vector can
|
|
||||||
be created using a flat list, a row vectors using an almost flat nested list::
|
|
||||||
|
|
||||||
>>> matrix([1, 2, 3])
|
|
||||||
matrix(
|
|
||||||
[['1.0'],
|
|
||||||
['2.0'],
|
|
||||||
['3.0']])
|
|
||||||
>>> matrix([[1, 2, 3]])
|
|
||||||
matrix(
|
|
||||||
[['1.0', '2.0', '3.0']])
|
|
||||||
|
|
||||||
Optionally vectors can be accessed like lists, using only a single index::
|
|
||||||
|
|
||||||
>>> x = matrix([1, 2, 3])
|
|
||||||
>>> x[1]
|
|
||||||
mpf('2.0')
|
|
||||||
>>> x[1,0]
|
|
||||||
mpf('2.0')
|
|
||||||
|
|
||||||
Other
|
|
||||||
.....
|
|
||||||
|
|
||||||
Like you probably expected, matrices can be printed::
|
|
||||||
|
|
||||||
>>> print randmatrix(3) # doctest:+SKIP
|
|
||||||
[ 0.782963853573023 0.802057689719883 0.427895717335467]
|
|
||||||
[0.0541876859348597 0.708243266653103 0.615134039977379]
|
|
||||||
[ 0.856151514955773 0.544759264818486 0.686210904770947]
|
|
||||||
|
|
||||||
Use ``nstr`` or ``nprint`` to specify the number of digits to print::
|
|
||||||
|
|
||||||
>>> nprint(randmatrix(5), 3) # doctest:+SKIP
|
|
||||||
[2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
|
|
||||||
[6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
|
|
||||||
[4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
|
|
||||||
[1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
|
|
||||||
[1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
|
|
||||||
|
|
||||||
As matrices are mutable, you will need to copy them sometimes::
|
|
||||||
|
|
||||||
>>> A = matrix(2)
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
>>> B = A.copy()
|
|
||||||
>>> B[0,0] = 1
|
|
||||||
>>> B
|
|
||||||
matrix(
|
|
||||||
[['1.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
|
|
||||||
Finally, it is possible to convert a matrix to a nested list. This is very useful,
|
|
||||||
as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
|
|
||||||
support this format::
|
|
||||||
|
|
||||||
>>> B.tolist()
|
|
||||||
[[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
|
|
||||||
|
|
||||||
|
|
||||||
Matrix operations
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
You can add and subtract matrices of compatible dimensions::
|
|
||||||
|
|
||||||
>>> A = matrix([[1, 2], [3, 4]])
|
|
||||||
>>> B = matrix([[-2, 4], [5, 9]])
|
|
||||||
>>> A + B
|
|
||||||
matrix(
|
|
||||||
[['-1.0', '6.0'],
|
|
||||||
['8.0', '13.0']])
|
|
||||||
>>> A - B
|
|
||||||
matrix(
|
|
||||||
[['3.0', '-2.0'],
|
|
||||||
['-2.0', '-5.0']])
|
|
||||||
>>> A + ones(3) # doctest:+ELLIPSIS
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
ValueError: incompatible dimensions for addition
|
|
||||||
|
|
||||||
It is possible to multiply or add matrices and scalars. In the latter case the
|
|
||||||
operation will be done element-wise::
|
|
||||||
|
|
||||||
>>> A * 2
|
|
||||||
matrix(
|
|
||||||
[['2.0', '4.0'],
|
|
||||||
['6.0', '8.0']])
|
|
||||||
>>> A / 4
|
|
||||||
matrix(
|
|
||||||
[['0.25', '0.5'],
|
|
||||||
['0.75', '1.0']])
|
|
||||||
>>> A - 1
|
|
||||||
matrix(
|
|
||||||
[['0.0', '1.0'],
|
|
||||||
['2.0', '3.0']])
|
|
||||||
|
|
||||||
Of course you can perform matrix multiplication, if the dimensions are
|
|
||||||
compatible::
|
|
||||||
|
|
||||||
>>> A * B
|
|
||||||
matrix(
|
|
||||||
[['8.0', '22.0'],
|
|
||||||
['14.0', '48.0']])
|
|
||||||
>>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
|
|
||||||
matrix(
|
|
||||||
[['2.0']])
|
|
||||||
|
|
||||||
You can raise powers of square matrices::
|
|
||||||
|
|
||||||
>>> A**2
|
|
||||||
matrix(
|
|
||||||
[['7.0', '10.0'],
|
|
||||||
['15.0', '22.0']])
|
|
||||||
|
|
||||||
Negative powers will calculate the inverse::
|
|
||||||
|
|
||||||
>>> A**-1
|
|
||||||
matrix(
|
|
||||||
[['-2.0', '1.0'],
|
|
||||||
['1.5', '-0.5']])
|
|
||||||
>>> A * A**-1
|
|
||||||
matrix(
|
|
||||||
[['1.0', '1.0842021724855e-19'],
|
|
||||||
['-2.16840434497101e-19', '1.0']])
|
|
||||||
|
|
||||||
Matrix transposition is straightforward::
|
|
||||||
|
|
||||||
>>> A = ones(2, 3)
|
|
||||||
>>> A
|
|
||||||
matrix(
|
|
||||||
[['1.0', '1.0', '1.0'],
|
|
||||||
['1.0', '1.0', '1.0']])
|
|
||||||
>>> A.T
|
|
||||||
matrix(
|
|
||||||
[['1.0', '1.0'],
|
|
||||||
['1.0', '1.0'],
|
|
||||||
['1.0', '1.0']])
|
|
||||||
|
|
||||||
Norms
|
|
||||||
.....
|
|
||||||
|
|
||||||
Sometimes you need to know how "large" a matrix or vector is. Due to their
|
|
||||||
multidimensional nature it's not possible to compare them, but there are
|
|
||||||
several functions to map a matrix or a vector to a positive real number, the
|
|
||||||
so called norms.
|
|
||||||
|
|
||||||
For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
|
|
||||||
used.
|
|
||||||
|
|
||||||
>>> x = matrix([-10, 2, 100])
|
|
||||||
>>> norm(x, 1)
|
|
||||||
mpf('112.0')
|
|
||||||
>>> norm(x, 2)
|
|
||||||
mpf('100.5186549850325')
|
|
||||||
>>> norm(x, inf)
|
|
||||||
mpf('100.0')
|
|
||||||
|
|
||||||
Please note that the 2-norm is the most used one, though it is more expensive
|
|
||||||
to calculate than the 1- or oo-norm.
|
|
||||||
|
|
||||||
It is possible to generalize some vector norms to matrix norm::
|
|
||||||
|
|
||||||
>>> A = matrix([[1, -1000], [100, 50]])
|
|
||||||
>>> mnorm(A, 1)
|
|
||||||
mpf('1050.0')
|
|
||||||
>>> mnorm(A, inf)
|
|
||||||
mpf('1001.0')
|
|
||||||
>>> mnorm(A, 'F')
|
|
||||||
mpf('1006.2310867787777')
|
|
||||||
|
|
||||||
The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
|
|
||||||
is hard to calculate and not available. The Frobenius-norm lacks some
|
|
||||||
mathematical properties you might expect from a norm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
self.__data = {}
|
|
||||||
# LU decompostion cache, this is useful when solving the same system
|
|
||||||
# multiple times, when calculating the inverse and when calculating the
|
|
||||||
# determinant
|
|
||||||
self._LU = None
|
|
||||||
convert = kwargs.get('force_type', self.ctx.convert)
|
|
||||||
if isinstance(args[0], (list, tuple)):
|
|
||||||
if isinstance(args[0][0], (list, tuple)):
|
|
||||||
# interpret nested list as matrix
|
|
||||||
A = args[0]
|
|
||||||
self.__rows = len(A)
|
|
||||||
self.__cols = len(A[0])
|
|
||||||
for i, row in enumerate(A):
|
|
||||||
for j, a in enumerate(row):
|
|
||||||
self[i, j] = convert(a)
|
|
||||||
else:
|
|
||||||
# interpret list as row vector
|
|
||||||
v = args[0]
|
|
||||||
self.__rows = len(v)
|
|
||||||
self.__cols = 1
|
|
||||||
for i, e in enumerate(v):
|
|
||||||
self[i, 0] = e
|
|
||||||
elif isinstance(args[0], int):
|
|
||||||
# create empty matrix of given dimensions
|
|
||||||
if len(args) == 1:
|
|
||||||
self.__rows = self.__cols = args[0]
|
|
||||||
else:
|
|
||||||
assert isinstance(args[1], int), 'expected int'
|
|
||||||
self.__rows = args[0]
|
|
||||||
self.__cols = args[1]
|
|
||||||
elif isinstance(args[0], _matrix):
|
|
||||||
A = args[0].copy()
|
|
||||||
self.__data = A._matrix__data
|
|
||||||
self.__rows = A._matrix__rows
|
|
||||||
self.__cols = A._matrix__cols
|
|
||||||
convert = kwargs.get('force_type', self.ctx.convert)
|
|
||||||
for i in xrange(A.__rows):
|
|
||||||
for j in xrange(A.__cols):
|
|
||||||
A[i,j] = convert(A[i,j])
|
|
||||||
elif hasattr(args[0], 'tolist'):
|
|
||||||
A = self.ctx.matrix(args[0].tolist())
|
|
||||||
self.__data = A._matrix__data
|
|
||||||
self.__rows = A._matrix__rows
|
|
||||||
self.__cols = A._matrix__cols
|
|
||||||
else:
|
|
||||||
raise TypeError('could not interpret given arguments')
|
|
||||||
|
|
||||||
def apply(self, f):
|
|
||||||
"""
|
|
||||||
Return a copy of self with the function `f` applied elementwise.
|
|
||||||
"""
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[i,j] = f(self[i,j])
|
|
||||||
return new
|
|
||||||
|
|
||||||
def __nstr__(self, n=None, **kwargs):
|
|
||||||
# Build table of string representations of the elements
|
|
||||||
res = []
|
|
||||||
# Track per-column max lengths for pretty alignment
|
|
||||||
maxlen = [0] * self.cols
|
|
||||||
for i in range(self.rows):
|
|
||||||
res.append([])
|
|
||||||
for j in range(self.cols):
|
|
||||||
if n:
|
|
||||||
string = self.ctx.nstr(self[i,j], n, **kwargs)
|
|
||||||
else:
|
|
||||||
string = str(self[i,j])
|
|
||||||
res[-1].append(string)
|
|
||||||
maxlen[j] = max(len(string), maxlen[j])
|
|
||||||
# Patch strings together
|
|
||||||
for i, row in enumerate(res):
|
|
||||||
for j, elem in enumerate(row):
|
|
||||||
# Pad each element up to maxlen so the columns line up
|
|
||||||
row[j] = elem.rjust(maxlen[j])
|
|
||||||
res[i] = "[" + colsep.join(row) + "]"
|
|
||||||
return rowsep.join(res)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.__nstr__()
|
|
||||||
|
|
||||||
def _toliststr(self, avoid_type=False):
|
|
||||||
"""
|
|
||||||
Create a list string from a matrix.
|
|
||||||
|
|
||||||
If avoid_type: avoid multiple 'mpf's.
|
|
||||||
"""
|
|
||||||
# XXX: should be something like self.ctx._types
|
|
||||||
typ = self.ctx.mpf
|
|
||||||
s = '['
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
s += '['
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
if not avoid_type or not isinstance(self[i,j], typ):
|
|
||||||
a = repr(self[i,j])
|
|
||||||
else:
|
|
||||||
a = "'" + str(self[i,j]) + "'"
|
|
||||||
s += a + ', '
|
|
||||||
s = s[:-2]
|
|
||||||
s += '],\n '
|
|
||||||
s = s[:-3]
|
|
||||||
s += ']'
|
|
||||||
return s
|
|
||||||
|
|
||||||
def tolist(self):
|
|
||||||
"""
|
|
||||||
Convert the matrix to a nested list.
|
|
||||||
"""
|
|
||||||
return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
if self.ctx.pretty:
|
|
||||||
return self.__str__()
|
|
||||||
s = 'matrix(\n'
|
|
||||||
s += self._toliststr(avoid_type=True) + ')'
|
|
||||||
return s
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
if type(key) is int:
|
|
||||||
# only sufficent for vectors
|
|
||||||
if self.__rows == 1:
|
|
||||||
key = (0, key)
|
|
||||||
elif self.__cols == 1:
|
|
||||||
key = (key, 0)
|
|
||||||
else:
|
|
||||||
raise IndexError('insufficient indices for matrix')
|
|
||||||
if key in self.__data:
|
|
||||||
return self.__data[key]
|
|
||||||
else:
|
|
||||||
if key[0] >= self.__rows or key[1] >= self.__cols:
|
|
||||||
raise IndexError('matrix index out of range')
|
|
||||||
return self.ctx.zero
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
if type(key) is int:
|
|
||||||
# only sufficent for vectors
|
|
||||||
if self.__rows == 1:
|
|
||||||
key = (0, key)
|
|
||||||
elif self.__cols == 1:
|
|
||||||
key = (key, 0)
|
|
||||||
else:
|
|
||||||
raise IndexError('insufficient indices for matrix')
|
|
||||||
if key[0] >= self.__rows or key[1] >= self.__cols:
|
|
||||||
raise IndexError('matrix index out of range')
|
|
||||||
value = self.ctx.convert(value)
|
|
||||||
if value: # only store non-zeros
|
|
||||||
self.__data[key] = value
|
|
||||||
elif key in self.__data:
|
|
||||||
del self.__data[key]
|
|
||||||
# TODO: maybe do this better, if the performance impact is significant
|
|
||||||
if self._LU:
|
|
||||||
self._LU = None
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
yield self[i,j]
|
|
||||||
|
|
||||||
def __mul__(self, other):
|
|
||||||
if isinstance(other, self.ctx.matrix):
|
|
||||||
# dot multiplication TODO: use Strassen's method?
|
|
||||||
if self.__cols != other.__rows:
|
|
||||||
raise ValueError('dimensions not compatible for multiplication')
|
|
||||||
new = self.ctx.matrix(self.__rows, other.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(other.__cols):
|
|
||||||
new[i, j] = self.ctx.fdot((self[i,k], other[k,j])
|
|
||||||
for k in xrange(other.__rows))
|
|
||||||
return new
|
|
||||||
else:
|
|
||||||
# try scalar multiplication
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[i, j] = other * self[i, j]
|
|
||||||
return new
|
|
||||||
|
|
||||||
def __rmul__(self, other):
|
|
||||||
# assume other is scalar and thus commutative
|
|
||||||
assert not isinstance(other, self.ctx.matrix)
|
|
||||||
return self.__mul__(other)
|
|
||||||
|
|
||||||
def __pow__(self, other):
|
|
||||||
# avoid cyclic import problems
|
|
||||||
#from linalg import inverse
|
|
||||||
if not isinstance(other, int):
|
|
||||||
raise ValueError('only integer exponents are supported')
|
|
||||||
if not self.__rows == self.__cols:
|
|
||||||
raise ValueError('only powers of square matrices are defined')
|
|
||||||
n = other
|
|
||||||
if n == 0:
|
|
||||||
return self.ctx.eye(self.__rows)
|
|
||||||
if n < 0:
|
|
||||||
n = -n
|
|
||||||
neg = True
|
|
||||||
else:
|
|
||||||
neg = False
|
|
||||||
i = n
|
|
||||||
y = 1
|
|
||||||
z = self.copy()
|
|
||||||
while i != 0:
|
|
||||||
if i % 2 == 1:
|
|
||||||
y = y * z
|
|
||||||
z = z*z
|
|
||||||
i = i // 2
|
|
||||||
if neg:
|
|
||||||
y = self.ctx.inverse(y)
|
|
||||||
return y
|
|
||||||
|
|
||||||
def __div__(self, other):
|
|
||||||
# assume other is scalar and do element-wise divison
|
|
||||||
assert not isinstance(other, self.ctx.matrix)
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[i,j] = self[i,j] / other
|
|
||||||
return new
|
|
||||||
|
|
||||||
__truediv__ = __div__
|
|
||||||
|
|
||||||
def __add__(self, other):
|
|
||||||
if isinstance(other, self.ctx.matrix):
|
|
||||||
if not (self.__rows == other.__rows and self.__cols == other.__cols):
|
|
||||||
raise ValueError('incompatible dimensions for addition')
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[i,j] = self[i,j] + other[i,j]
|
|
||||||
return new
|
|
||||||
else:
|
|
||||||
# assume other is scalar and add element-wise
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[i,j] += self[i,j] + other
|
|
||||||
return new
|
|
||||||
|
|
||||||
def __radd__(self, other):
|
|
||||||
return self.__add__(other)
|
|
||||||
|
|
||||||
def __sub__(self, other):
|
|
||||||
if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
|
|
||||||
and self.__cols == other.__cols):
|
|
||||||
raise ValueError('incompatible dimensions for substraction')
|
|
||||||
return self.__add__(other * (-1))
|
|
||||||
|
|
||||||
def __neg__(self):
|
|
||||||
return (-1) * self
|
|
||||||
|
|
||||||
def __rsub__(self, other):
|
|
||||||
return -self + other
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self.__rows == other.__rows and self.__cols == other.__cols \
|
|
||||||
and self.__data == other.__data
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
if self.rows == 1:
|
|
||||||
return self.cols
|
|
||||||
elif self.cols == 1:
|
|
||||||
return self.rows
|
|
||||||
else:
|
|
||||||
return self.rows # do it like numpy
|
|
||||||
|
|
||||||
def __getrows(self):
|
|
||||||
return self.__rows
|
|
||||||
|
|
||||||
def __setrows(self, value):
|
|
||||||
for key in self.__data.copy().iterkeys():
|
|
||||||
if key[0] >= value:
|
|
||||||
del self.__data[key]
|
|
||||||
self.__rows = value
|
|
||||||
|
|
||||||
rows = property(__getrows, __setrows, doc='number of rows')
|
|
||||||
|
|
||||||
def __getcols(self):
|
|
||||||
return self.__cols
|
|
||||||
|
|
||||||
def __setcols(self, value):
|
|
||||||
for key in self.__data.copy().iterkeys():
|
|
||||||
if key[1] >= value:
|
|
||||||
del self.__data[key]
|
|
||||||
self.__cols = value
|
|
||||||
|
|
||||||
cols = property(__getcols, __setcols, doc='number of columns')
|
|
||||||
|
|
||||||
def transpose(self):
|
|
||||||
new = self.ctx.matrix(self.__cols, self.__rows)
|
|
||||||
for i in xrange(self.__rows):
|
|
||||||
for j in xrange(self.__cols):
|
|
||||||
new[j,i] = self[i,j]
|
|
||||||
return new
|
|
||||||
|
|
||||||
T = property(transpose)
|
|
||||||
|
|
||||||
def conjugate(self):
|
|
||||||
return self.apply(self.ctx.conj)
|
|
||||||
|
|
||||||
def transpose_conj(self):
|
|
||||||
return self.conjugate().transpose()
|
|
||||||
|
|
||||||
H = property(transpose_conj)
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
new = self.ctx.matrix(self.__rows, self.__cols)
|
|
||||||
new.__data = self.__data.copy()
|
|
||||||
return new
|
|
||||||
|
|
||||||
__copy__ = copy
|
|
||||||
|
|
||||||
def column(self, n):
|
|
||||||
m = self.ctx.matrix(self.rows, 1)
|
|
||||||
for i in range(self.rows):
|
|
||||||
m[i] = self[i,n]
|
|
||||||
return m
|
|
||||||
|
|
||||||
class MatrixMethods(object):
|
|
||||||
|
|
||||||
def __init__(ctx):
|
|
||||||
# XXX: subclass
|
|
||||||
ctx.matrix = type('matrix', (_matrix,), {})
|
|
||||||
ctx.matrix.ctx = ctx
|
|
||||||
ctx.matrix.convert = ctx.convert
|
|
||||||
|
|
||||||
def eye(ctx, n, **kwargs):
|
|
||||||
"""
|
|
||||||
Create square identity matrix n x n.
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(n, **kwargs)
|
|
||||||
for i in xrange(n):
|
|
||||||
A[i,i] = 1
|
|
||||||
return A
|
|
||||||
|
|
||||||
def diag(ctx, diagonal, **kwargs):
|
|
||||||
"""
|
|
||||||
Create square diagonal matrix using given list.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from mpmath import diag, mp
|
|
||||||
>>> mp.pretty = False
|
|
||||||
>>> diag([1, 2, 3])
|
|
||||||
matrix(
|
|
||||||
[['1.0', '0.0', '0.0'],
|
|
||||||
['0.0', '2.0', '0.0'],
|
|
||||||
['0.0', '0.0', '3.0']])
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(len(diagonal), **kwargs)
|
|
||||||
for i in xrange(len(diagonal)):
|
|
||||||
A[i,i] = diagonal[i]
|
|
||||||
return A
|
|
||||||
|
|
||||||
def zeros(ctx, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create matrix m x n filled with zeros.
|
|
||||||
One given dimension will create square matrix n x n.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from mpmath import zeros, mp
|
|
||||||
>>> mp.pretty = False
|
|
||||||
>>> zeros(2)
|
|
||||||
matrix(
|
|
||||||
[['0.0', '0.0'],
|
|
||||||
['0.0', '0.0']])
|
|
||||||
"""
|
|
||||||
if len(args) == 1:
|
|
||||||
m = n = args[0]
|
|
||||||
elif len(args) == 2:
|
|
||||||
m = args[0]
|
|
||||||
n = args[1]
|
|
||||||
else:
|
|
||||||
raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
|
|
||||||
A = ctx.matrix(m, n, **kwargs)
|
|
||||||
for i in xrange(m):
|
|
||||||
for j in xrange(n):
|
|
||||||
A[i,j] = 0
|
|
||||||
return A
|
|
||||||
|
|
||||||
def ones(ctx, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create matrix m x n filled with ones.
|
|
||||||
One given dimension will create square matrix n x n.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from mpmath import ones, mp
|
|
||||||
>>> mp.pretty = False
|
|
||||||
>>> ones(2)
|
|
||||||
matrix(
|
|
||||||
[['1.0', '1.0'],
|
|
||||||
['1.0', '1.0']])
|
|
||||||
"""
|
|
||||||
if len(args) == 1:
|
|
||||||
m = n = args[0]
|
|
||||||
elif len(args) == 2:
|
|
||||||
m = args[0]
|
|
||||||
n = args[1]
|
|
||||||
else:
|
|
||||||
raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
|
|
||||||
A = ctx.matrix(m, n, **kwargs)
|
|
||||||
for i in xrange(m):
|
|
||||||
for j in xrange(n):
|
|
||||||
A[i,j] = 1
|
|
||||||
return A
|
|
||||||
|
|
||||||
def hilbert(ctx, m, n=None):
|
|
||||||
"""
|
|
||||||
Create (pseudo) hilbert matrix m x n.
|
|
||||||
One given dimension will create hilbert matrix n x n.
|
|
||||||
|
|
||||||
The matrix is very ill-conditioned and symmetric, positive definite if
|
|
||||||
square.
|
|
||||||
"""
|
|
||||||
if n is None:
|
|
||||||
n = m
|
|
||||||
A = ctx.matrix(m, n)
|
|
||||||
for i in xrange(m):
|
|
||||||
for j in xrange(n):
|
|
||||||
A[i,j] = ctx.one / (i + j + 1)
|
|
||||||
return A
|
|
||||||
|
|
||||||
def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a random m x n matrix.
|
|
||||||
|
|
||||||
All values are >= min and <max.
|
|
||||||
n defaults to m.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from mpmath import randmatrix
|
|
||||||
>>> randmatrix(2) # doctest:+SKIP
|
|
||||||
matrix(
|
|
||||||
[['0.53491598236191806', '0.57195669543302752'],
|
|
||||||
['0.85589992269513615', '0.82444367501382143']])
|
|
||||||
"""
|
|
||||||
if not n:
|
|
||||||
n = m
|
|
||||||
A = ctx.matrix(m, n, **kwargs)
|
|
||||||
for i in xrange(m):
|
|
||||||
for j in xrange(n):
|
|
||||||
A[i,j] = ctx.rand() * (max - min) + min
|
|
||||||
return A
|
|
||||||
|
|
||||||
def swap_row(ctx, A, i, j):
|
|
||||||
"""
|
|
||||||
Swap row i with row j.
|
|
||||||
"""
|
|
||||||
if i == j:
|
|
||||||
return
|
|
||||||
if isinstance(A, ctx.matrix):
|
|
||||||
for k in xrange(A.cols):
|
|
||||||
A[i,k], A[j,k] = A[j,k], A[i,k]
|
|
||||||
elif isinstance(A, list):
|
|
||||||
A[i], A[j] = A[j], A[i]
|
|
||||||
else:
|
|
||||||
raise TypeError('could not interpret type')
|
|
||||||
|
|
||||||
def extend(ctx, A, b):
|
|
||||||
"""
|
|
||||||
Extend matrix A with column b and return result.
|
|
||||||
"""
|
|
||||||
assert isinstance(A, ctx.matrix)
|
|
||||||
assert A.rows == len(b)
|
|
||||||
A = A.copy()
|
|
||||||
A.cols += 1
|
|
||||||
for i in xrange(A.rows):
|
|
||||||
A[i, A.cols-1] = b[i]
|
|
||||||
return A
|
|
||||||
|
|
||||||
def norm(ctx, x, p=2):
|
|
||||||
r"""
|
|
||||||
Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
|
|
||||||
`\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.
|
|
||||||
|
|
||||||
Special cases:
|
|
||||||
|
|
||||||
If *x* is not iterable, this just returns ``absmax(x)``.
|
|
||||||
|
|
||||||
``p=1`` gives the sum of absolute values.
|
|
||||||
|
|
||||||
``p=2`` is the standard Euclidean vector norm.
|
|
||||||
|
|
||||||
``p=inf`` gives the magnitude of the largest element.
|
|
||||||
|
|
||||||
For *x* a matrix, ``p=2`` is the Frobenius norm.
|
|
||||||
For operator matrix norms, use :func:`mnorm` instead.
|
|
||||||
|
|
||||||
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
|
||||||
to specify the infinity norm.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> x = matrix([-10, 2, 100])
|
|
||||||
>>> norm(x, 1)
|
|
||||||
mpf('112.0')
|
|
||||||
>>> norm(x, 2)
|
|
||||||
mpf('100.5186549850325')
|
|
||||||
>>> norm(x, inf)
|
|
||||||
mpf('100.0')
|
|
||||||
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
iter(x)
|
|
||||||
except TypeError:
|
|
||||||
return ctx.absmax(x)
|
|
||||||
if type(p) is not int:
|
|
||||||
p = ctx.convert(p)
|
|
||||||
if p == ctx.inf:
|
|
||||||
return max(ctx.absmax(i) for i in x)
|
|
||||||
elif p == 1:
|
|
||||||
return ctx.fsum(x, absolute=1)
|
|
||||||
elif p == 2:
|
|
||||||
return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
|
|
||||||
elif p > 1:
|
|
||||||
return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
|
|
||||||
else:
|
|
||||||
raise ValueError('p has to be >= 1')
|
|
||||||
|
|
||||||
def mnorm(ctx, A, p=1):
|
|
||||||
r"""
|
|
||||||
Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
|
|
||||||
are supported:
|
|
||||||
|
|
||||||
``p=1`` gives the 1-norm (maximal column sum)
|
|
||||||
|
|
||||||
``p=inf`` gives the `\infty`-norm (maximal row sum).
|
|
||||||
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
|
||||||
|
|
||||||
``p=2`` (not implemented) for a square matrix is the usual spectral
|
|
||||||
matrix norm, i.e. the largest singular value.
|
|
||||||
|
|
||||||
``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
|
|
||||||
Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
|
|
||||||
approximation of the spectral norm and satisfies
|
|
||||||
|
|
||||||
.. math ::
|
|
||||||
|
|
||||||
\frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F
|
|
||||||
|
|
||||||
The Frobenius norm lacks some mathematical properties that might
|
|
||||||
be expected of a norm.
|
|
||||||
|
|
||||||
For general elementwise `p`-norms, use :func:`norm` instead.
|
|
||||||
|
|
||||||
**Examples**
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 15; mp.pretty = False
|
|
||||||
>>> A = matrix([[1, -1000], [100, 50]])
|
|
||||||
>>> mnorm(A, 1)
|
|
||||||
mpf('1050.0')
|
|
||||||
>>> mnorm(A, inf)
|
|
||||||
mpf('1001.0')
|
|
||||||
>>> mnorm(A, 'F')
|
|
||||||
mpf('1006.2310867787777')
|
|
||||||
|
|
||||||
"""
|
|
||||||
A = ctx.matrix(A)
|
|
||||||
if type(p) is not int:
|
|
||||||
if type(p) is str and 'frobenius'.startswith(p.lower()):
|
|
||||||
return ctx.norm(A, 2)
|
|
||||||
p = ctx.convert(p)
|
|
||||||
m, n = A.rows, A.cols
|
|
||||||
if p == 1:
|
|
||||||
return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
|
|
||||||
elif p == ctx.inf:
|
|
||||||
return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("matrix p-norm for arbitrary p")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
@ -1,106 +0,0 @@
|
||||||
# TODO: use gmpy.mpq when available?
|
|
||||||
|
|
||||||
class mpq(tuple):
|
|
||||||
"""
|
|
||||||
Rational number type, only intended for internal use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
def _mpmath_(self, prec, rounding):
|
|
||||||
# XXX
|
|
||||||
return mp.make_mpf(from_rational(self[0], self[1], prec, rounding))
|
|
||||||
#(mpf(self[0])/self[1])._mpf_
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __int__(self):
|
|
||||||
a, b = self
|
|
||||||
return a // b
|
|
||||||
|
|
||||||
def __abs__(self):
|
|
||||||
a, b = self
|
|
||||||
return mpq((abs(a), b))
|
|
||||||
|
|
||||||
def __neg__(self):
|
|
||||||
a, b = self
|
|
||||||
return mpq((-a, b))
|
|
||||||
|
|
||||||
def __nonzero__(self):
|
|
||||||
return bool(self[0])
|
|
||||||
|
|
||||||
def __cmp__(self, other):
|
|
||||||
if type(other) is int and self[1] == 1:
|
|
||||||
return cmp(self[0], other)
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __add__(self, other):
|
|
||||||
if isinstance(other, mpq):
|
|
||||||
a, b = self
|
|
||||||
c, d = other
|
|
||||||
return mpq((a*d+b*c, b*d))
|
|
||||||
if isinstance(other, (int, long)):
|
|
||||||
a, b = self
|
|
||||||
return mpq((a+b*other, b))
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
__radd__ = __add__
|
|
||||||
|
|
||||||
def __sub__(self, other):
|
|
||||||
if isinstance(other, mpq):
|
|
||||||
a, b = self
|
|
||||||
c, d = other
|
|
||||||
return mpq((a*d-b*c, b*d))
|
|
||||||
if isinstance(other, (int, long)):
|
|
||||||
a, b = self
|
|
||||||
return mpq((a-b*other, b))
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __rsub__(self, other):
|
|
||||||
if isinstance(other, mpq):
|
|
||||||
a, b = self
|
|
||||||
c, d = other
|
|
||||||
return mpq((b*c-a*d, b*d))
|
|
||||||
if isinstance(other, (int, long)):
|
|
||||||
a, b = self
|
|
||||||
return mpq((b*other-a, b))
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __mul__(self, other):
|
|
||||||
if isinstance(other, mpq):
|
|
||||||
a, b = self
|
|
||||||
c, d = other
|
|
||||||
return mpq((a*c, b*d))
|
|
||||||
if isinstance(other, (int, long)):
|
|
||||||
a, b = self
|
|
||||||
return mpq((a*other, b))
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __div__(self, other):
|
|
||||||
if isinstance(other, (int, long)):
|
|
||||||
if other:
|
|
||||||
a, b = self
|
|
||||||
return mpq((a, b*other))
|
|
||||||
raise ZeroDivisionError
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __pow__(self, other):
|
|
||||||
if type(other) is int:
|
|
||||||
a, b = self
|
|
||||||
return mpq((a**other, b**other))
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
__rmul__ = __mul__
|
|
||||||
|
|
||||||
|
|
||||||
mpq_1 = mpq((1,1))
|
|
||||||
mpq_0 = mpq((0,1))
|
|
||||||
mpq_1_2 = mpq((1,2))
|
|
||||||
mpq_3_2 = mpq((3,2))
|
|
||||||
mpq_1_4 = mpq((1,4))
|
|
||||||
mpq_1_16 = mpq((1,16))
|
|
||||||
mpq_3_16 = mpq((3,16))
|
|
||||||
mpq_5_2 = mpq((5,2))
|
|
||||||
mpq_3_4 = mpq((3,4))
|
|
||||||
mpq_7_4 = mpq((7,4))
|
|
||||||
mpq_5_4 = mpq((5,4))
|
|
||||||
|
|
||||||
|
|
@ -1,159 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
"""
|
|
||||||
python runtests.py -py
|
|
||||||
Use py.test to run tests (more useful for debugging)
|
|
||||||
|
|
||||||
python runtests.py -psyco
|
|
||||||
Enable psyco to make tests run about 50% faster
|
|
||||||
|
|
||||||
python runtests.py -coverage
|
|
||||||
Generate test coverage report. Statistics are written to /tmp
|
|
||||||
|
|
||||||
python runtests.py -profile
|
|
||||||
Generate profile stats (this is much slower)
|
|
||||||
|
|
||||||
python runtests.py -nogmpy
|
|
||||||
Run tests without using GMPY even if it exists
|
|
||||||
|
|
||||||
python runtests.py -strict
|
|
||||||
Enforce extra tests in normalize()
|
|
||||||
|
|
||||||
python runtests.py -local
|
|
||||||
Insert '../..' at the beginning of sys.path to use local mpmath
|
|
||||||
|
|
||||||
Additional arguments are used to filter the tests to run. Only files that have
|
|
||||||
one of the arguments in their name are executed.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys, os, traceback
|
|
||||||
|
|
||||||
if "-psyco" in sys.argv:
|
|
||||||
sys.argv.remove('-psyco')
|
|
||||||
import psyco
|
|
||||||
psyco.full()
|
|
||||||
|
|
||||||
profile = False
|
|
||||||
if "-profile" in sys.argv:
|
|
||||||
sys.argv.remove('-profile')
|
|
||||||
profile = True
|
|
||||||
|
|
||||||
coverage = False
|
|
||||||
if "-coverage" in sys.argv:
|
|
||||||
sys.argv.remove('-coverage')
|
|
||||||
coverage = True
|
|
||||||
|
|
||||||
if "-nogmpy" in sys.argv:
|
|
||||||
sys.argv.remove('-nogmpy')
|
|
||||||
os.environ['MPMATH_NOGMPY'] = 'Y'
|
|
||||||
|
|
||||||
if "-strict" in sys.argv:
|
|
||||||
sys.argv.remove('-strict')
|
|
||||||
os.environ['MPMATH_STRICT'] = 'Y'
|
|
||||||
|
|
||||||
if "-local" in sys.argv:
|
|
||||||
sys.argv.remove('-local')
|
|
||||||
importdir = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]),
|
|
||||||
'../..'))
|
|
||||||
else:
|
|
||||||
importdir = ''
|
|
||||||
|
|
||||||
# TODO: add a flag for this
|
|
||||||
testdir = ''
|
|
||||||
|
|
||||||
def testit(importdir='', testdir=''):
|
|
||||||
"""Run all tests in testdir while importing from importdir."""
|
|
||||||
if importdir:
|
|
||||||
sys.path.insert(1, importdir)
|
|
||||||
if testdir:
|
|
||||||
sys.path.insert(1, testdir)
|
|
||||||
import os.path
|
|
||||||
import mpmath
|
|
||||||
print "mpmath imported from", os.path.dirname(mpmath.__file__)
|
|
||||||
print "mpmath backend:", mpmath.libmp.backend.BACKEND
|
|
||||||
print "mpmath mp class:", repr(mpmath.mp)
|
|
||||||
print "mpmath version:", mpmath.__version__
|
|
||||||
print "Python version:", sys.version
|
|
||||||
print
|
|
||||||
if "-py" in sys.argv:
|
|
||||||
sys.argv.remove('-py')
|
|
||||||
import py
|
|
||||||
py.test.cmdline.main()
|
|
||||||
else:
|
|
||||||
import glob
|
|
||||||
from timeit import default_timer as clock
|
|
||||||
modules = []
|
|
||||||
args = sys.argv[1:]
|
|
||||||
# search for tests in directory of this file if not otherwise specified
|
|
||||||
if not testdir:
|
|
||||||
pattern = os.path.dirname(sys.argv[0])
|
|
||||||
else:
|
|
||||||
pattern = testdir
|
|
||||||
if pattern:
|
|
||||||
pattern += '/'
|
|
||||||
pattern += 'test*.py'
|
|
||||||
# look for tests (respecting specified filter)
|
|
||||||
for f in glob.glob(pattern):
|
|
||||||
name = os.path.splitext(os.path.basename(f))[0]
|
|
||||||
# If run as a script, only run tests given as args, if any are given
|
|
||||||
if args and __name__ == "__main__":
|
|
||||||
ok = False
|
|
||||||
for arg in args:
|
|
||||||
if arg in name:
|
|
||||||
ok = True
|
|
||||||
break
|
|
||||||
if not ok:
|
|
||||||
continue
|
|
||||||
module = __import__(name)
|
|
||||||
priority = module.__dict__.get('priority', 100)
|
|
||||||
if priority == 666:
|
|
||||||
modules = [[priority, name, module]]
|
|
||||||
break
|
|
||||||
modules.append([priority, name, module])
|
|
||||||
# execute tests
|
|
||||||
modules.sort()
|
|
||||||
tstart = clock()
|
|
||||||
for priority, name, module in modules:
|
|
||||||
print name
|
|
||||||
for f in sorted(module.__dict__.keys()):
|
|
||||||
if f.startswith('test_'):
|
|
||||||
if coverage and ('numpy' in f):
|
|
||||||
continue
|
|
||||||
print " ", f[5:].ljust(25),
|
|
||||||
t1 = clock()
|
|
||||||
try:
|
|
||||||
module.__dict__[f]()
|
|
||||||
except:
|
|
||||||
etype, evalue, trb = sys.exc_info()
|
|
||||||
if etype in (KeyboardInterrupt, SystemExit):
|
|
||||||
raise
|
|
||||||
print
|
|
||||||
print "TEST FAILED!"
|
|
||||||
print
|
|
||||||
traceback.print_exc()
|
|
||||||
t2 = clock()
|
|
||||||
print "ok", " ", ("%.7f" % (t2-t1)), "s"
|
|
||||||
tend = clock()
|
|
||||||
print
|
|
||||||
print "finished tests in", ("%.2f" % (tend-tstart)), "seconds"
|
|
||||||
# clean sys.path
|
|
||||||
if importdir:
|
|
||||||
sys.path.remove(importdir)
|
|
||||||
if testdir:
|
|
||||||
sys.path.remove(testdir)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
if profile:
|
|
||||||
import cProfile
|
|
||||||
cProfile.run("testit('%s', '%s')" % (importdir, testdir), sort=1)
|
|
||||||
elif coverage:
|
|
||||||
import trace
|
|
||||||
tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix],
|
|
||||||
trace=0, count=1)
|
|
||||||
tracer.run('testit(importdir, testdir)')
|
|
||||||
r = tracer.results()
|
|
||||||
r.write_results(show_missing=True, summary=True, coverdir="/tmp")
|
|
||||||
else:
|
|
||||||
testit(importdir, testdir)
|
|
||||||
|
|
||||||
|
|
@ -1,161 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_type_compare():
|
|
||||||
assert mpf(2) == mpc(2,0)
|
|
||||||
assert mpf(0) == mpc(0)
|
|
||||||
assert mpf(2) != mpc(2, 0.00001)
|
|
||||||
assert mpf(2) == 2.0
|
|
||||||
assert mpf(2) != 3.0
|
|
||||||
assert mpf(2) == 2
|
|
||||||
assert mpf(2) != '2.0'
|
|
||||||
assert mpc(2) != '2.0'
|
|
||||||
|
|
||||||
def test_add():
|
|
||||||
assert mpf(2.5) + mpf(3) == 5.5
|
|
||||||
assert mpf(2.5) + 3 == 5.5
|
|
||||||
assert mpf(2.5) + 3.0 == 5.5
|
|
||||||
assert 3 + mpf(2.5) == 5.5
|
|
||||||
assert 3.0 + mpf(2.5) == 5.5
|
|
||||||
assert (3+0j) + mpf(2.5) == 5.5
|
|
||||||
assert mpc(2.5) + mpf(3) == 5.5
|
|
||||||
assert mpc(2.5) + 3 == 5.5
|
|
||||||
assert mpc(2.5) + 3.0 == 5.5
|
|
||||||
assert mpc(2.5) + (3+0j) == 5.5
|
|
||||||
assert 3 + mpc(2.5) == 5.5
|
|
||||||
assert 3.0 + mpc(2.5) == 5.5
|
|
||||||
assert (3+0j) + mpc(2.5) == 5.5
|
|
||||||
|
|
||||||
def test_sub():
|
|
||||||
assert mpf(2.5) - mpf(3) == -0.5
|
|
||||||
assert mpf(2.5) - 3 == -0.5
|
|
||||||
assert mpf(2.5) - 3.0 == -0.5
|
|
||||||
assert 3 - mpf(2.5) == 0.5
|
|
||||||
assert 3.0 - mpf(2.5) == 0.5
|
|
||||||
assert (3+0j) - mpf(2.5) == 0.5
|
|
||||||
assert mpc(2.5) - mpf(3) == -0.5
|
|
||||||
assert mpc(2.5) - 3 == -0.5
|
|
||||||
assert mpc(2.5) - 3.0 == -0.5
|
|
||||||
assert mpc(2.5) - (3+0j) == -0.5
|
|
||||||
assert 3 - mpc(2.5) == 0.5
|
|
||||||
assert 3.0 - mpc(2.5) == 0.5
|
|
||||||
assert (3+0j) - mpc(2.5) == 0.5
|
|
||||||
|
|
||||||
def test_mul():
|
|
||||||
assert mpf(2.5) * mpf(3) == 7.5
|
|
||||||
assert mpf(2.5) * 3 == 7.5
|
|
||||||
assert mpf(2.5) * 3.0 == 7.5
|
|
||||||
assert 3 * mpf(2.5) == 7.5
|
|
||||||
assert 3.0 * mpf(2.5) == 7.5
|
|
||||||
assert (3+0j) * mpf(2.5) == 7.5
|
|
||||||
assert mpc(2.5) * mpf(3) == 7.5
|
|
||||||
assert mpc(2.5) * 3 == 7.5
|
|
||||||
assert mpc(2.5) * 3.0 == 7.5
|
|
||||||
assert mpc(2.5) * (3+0j) == 7.5
|
|
||||||
assert 3 * mpc(2.5) == 7.5
|
|
||||||
assert 3.0 * mpc(2.5) == 7.5
|
|
||||||
assert (3+0j) * mpc(2.5) == 7.5
|
|
||||||
|
|
||||||
def test_div():
|
|
||||||
assert mpf(6) / mpf(3) == 2.0
|
|
||||||
assert mpf(6) / 3 == 2.0
|
|
||||||
assert mpf(6) / 3.0 == 2.0
|
|
||||||
assert 6 / mpf(3) == 2.0
|
|
||||||
assert 6.0 / mpf(3) == 2.0
|
|
||||||
assert (6+0j) / mpf(3.0) == 2.0
|
|
||||||
assert mpc(6) / mpf(3) == 2.0
|
|
||||||
assert mpc(6) / 3 == 2.0
|
|
||||||
assert mpc(6) / 3.0 == 2.0
|
|
||||||
assert mpc(6) / (3+0j) == 2.0
|
|
||||||
assert 6 / mpc(3) == 2.0
|
|
||||||
assert 6.0 / mpc(3) == 2.0
|
|
||||||
assert (6+0j) / mpc(3) == 2.0
|
|
||||||
|
|
||||||
def test_pow():
|
|
||||||
assert mpf(6) ** mpf(3) == 216.0
|
|
||||||
assert mpf(6) ** 3 == 216.0
|
|
||||||
assert mpf(6) ** 3.0 == 216.0
|
|
||||||
assert 6 ** mpf(3) == 216.0
|
|
||||||
assert 6.0 ** mpf(3) == 216.0
|
|
||||||
assert (6+0j) ** mpf(3.0) == 216.0
|
|
||||||
assert mpc(6) ** mpf(3) == 216.0
|
|
||||||
assert mpc(6) ** 3 == 216.0
|
|
||||||
assert mpc(6) ** 3.0 == 216.0
|
|
||||||
assert mpc(6) ** (3+0j) == 216.0
|
|
||||||
assert 6 ** mpc(3) == 216.0
|
|
||||||
assert 6.0 ** mpc(3) == 216.0
|
|
||||||
assert (6+0j) ** mpc(3) == 216.0
|
|
||||||
|
|
||||||
def test_mixed_misc():
|
|
||||||
assert 1 + mpf(3) == mpf(3) + 1 == 4
|
|
||||||
assert 1 - mpf(3) == -(mpf(3) - 1) == -2
|
|
||||||
assert 3 * mpf(2) == mpf(2) * 3 == 6
|
|
||||||
assert 6 / mpf(2) == mpf(6) / 2 == 3
|
|
||||||
assert 1.0 + mpf(3) == mpf(3) + 1.0 == 4
|
|
||||||
assert 1.0 - mpf(3) == -(mpf(3) - 1.0) == -2
|
|
||||||
assert 3.0 * mpf(2) == mpf(2) * 3.0 == 6
|
|
||||||
assert 6.0 / mpf(2) == mpf(6) / 2.0 == 3
|
|
||||||
|
|
||||||
def test_add_misc():
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpf(4) + mpf(-70) == -66
|
|
||||||
assert mpf(1) + mpf(1.1)/80 == 1 + 1.1/80
|
|
||||||
assert mpf((1, 10000000000)) + mpf(3) == mpf((1, 10000000000))
|
|
||||||
assert mpf(3) + mpf((1, 10000000000)) == mpf((1, 10000000000))
|
|
||||||
assert mpf((1, -10000000000)) + mpf(3) == mpf(3)
|
|
||||||
assert mpf(3) + mpf((1, -10000000000)) == mpf(3)
|
|
||||||
assert mpf(1) + 1e-15 != 1
|
|
||||||
assert mpf(1) + 1e-20 == 1
|
|
||||||
assert mpf(1.07e-22) + 0 == mpf(1.07e-22)
|
|
||||||
assert mpf(0) + mpf(1.07e-22) == mpf(1.07e-22)
|
|
||||||
|
|
||||||
def test_complex_misc():
|
|
||||||
# many more tests needed
|
|
||||||
assert 1 + mpc(2) == 3
|
|
||||||
assert not mpc(2).ae(2 + 1e-13)
|
|
||||||
assert mpc(2+1e-15j).ae(2)
|
|
||||||
|
|
||||||
def test_complex_zeros():
|
|
||||||
for a in [0,2]:
|
|
||||||
for b in [0,3]:
|
|
||||||
for c in [0,4]:
|
|
||||||
for d in [0,5]:
|
|
||||||
assert mpc(a,b)*mpc(c,d) == complex(a,b)*complex(c,d)
|
|
||||||
|
|
||||||
def test_hash():
|
|
||||||
for i in range(-256, 256):
|
|
||||||
assert hash(mpf(i)) == hash(i)
|
|
||||||
assert hash(mpf(0.5)) == hash(0.5)
|
|
||||||
assert hash(mpc(2,3)) == hash(2+3j)
|
|
||||||
# Check that this doesn't fail
|
|
||||||
assert hash(inf)
|
|
||||||
# Check that overflow doesn't assign equal hashes to large numbers
|
|
||||||
assert hash(mpf('1e1000')) != hash('1e10000')
|
|
||||||
assert hash(mpc(100,'1e1000')) != hash(mpc(200,'1e1000'))
|
|
||||||
|
|
||||||
def test_arithmetic_functions():
|
|
||||||
import operator
|
|
||||||
ops = [(operator.add, fadd), (operator.sub, fsub), (operator.mul, fmul),
|
|
||||||
(operator.div, fdiv)]
|
|
||||||
a = mpf(0.27)
|
|
||||||
b = mpf(1.13)
|
|
||||||
c = mpc(0.51+2.16j)
|
|
||||||
d = mpc(1.08-0.99j)
|
|
||||||
for x in [a,b,c,d]:
|
|
||||||
for y in [a,b,c,d]:
|
|
||||||
for op, fop in ops:
|
|
||||||
if fop is not fdiv:
|
|
||||||
mp.prec = 200
|
|
||||||
z0 = op(x,y)
|
|
||||||
mp.prec = 60
|
|
||||||
z1 = op(x,y)
|
|
||||||
mp.prec = 53
|
|
||||||
z2 = op(x,y)
|
|
||||||
assert fop(x, y, prec=60) == z1
|
|
||||||
assert fop(x, y) == z2
|
|
||||||
if fop is not fdiv:
|
|
||||||
assert fop(x, y, prec=inf) == z0
|
|
||||||
assert fop(x, y, dps=inf) == z0
|
|
||||||
assert fop(x, y, exact=True) == z0
|
|
||||||
assert fneg(fneg(z1, exact=True), prec=inf) == z1
|
|
||||||
assert fneg(z1) == -(+z1)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
@ -1,172 +0,0 @@
|
||||||
"""
|
|
||||||
Test bit-level integer and mpf operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
from mpmath import *
|
|
||||||
from mpmath.libmp import *
|
|
||||||
|
|
||||||
def test_bitcount():
|
|
||||||
assert bitcount(0) == 0
|
|
||||||
assert bitcount(1) == 1
|
|
||||||
assert bitcount(7) == 3
|
|
||||||
assert bitcount(8) == 4
|
|
||||||
assert bitcount(2**100) == 101
|
|
||||||
assert bitcount(2**100-1) == 100
|
|
||||||
|
|
||||||
def test_trailing():
|
|
||||||
assert trailing(0) == 0
|
|
||||||
assert trailing(1) == 0
|
|
||||||
assert trailing(2) == 1
|
|
||||||
assert trailing(7) == 0
|
|
||||||
assert trailing(8) == 3
|
|
||||||
assert trailing(2**100) == 100
|
|
||||||
assert trailing(2**100-1) == 0
|
|
||||||
|
|
||||||
def test_round_down():
|
|
||||||
assert from_man_exp(0, -4, 4, round_down)[:3] == (0, 0, 0)
|
|
||||||
assert from_man_exp(0xf0, -4, 4, round_down)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf1, -4, 4, round_down)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xff, -4, 4, round_down)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(-0xf0, -4, 4, round_down)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf1, -4, 4, round_down)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xff, -4, 4, round_down)[:3] == (1, 15, 0)
|
|
||||||
|
|
||||||
def test_round_up():
|
|
||||||
assert from_man_exp(0, -4, 4, round_up)[:3] == (0, 0, 0)
|
|
||||||
assert from_man_exp(0xf0, -4, 4, round_up)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf1, -4, 4, round_up)[:3] == (0, 1, 4)
|
|
||||||
assert from_man_exp(0xff, -4, 4, round_up)[:3] == (0, 1, 4)
|
|
||||||
assert from_man_exp(-0xf0, -4, 4, round_up)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf1, -4, 4, round_up)[:3] == (1, 1, 4)
|
|
||||||
assert from_man_exp(-0xff, -4, 4, round_up)[:3] == (1, 1, 4)
|
|
||||||
|
|
||||||
def test_round_floor():
|
|
||||||
assert from_man_exp(0, -4, 4, round_floor)[:3] == (0, 0, 0)
|
|
||||||
assert from_man_exp(0xf0, -4, 4, round_floor)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf1, -4, 4, round_floor)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xff, -4, 4, round_floor)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(-0xf0, -4, 4, round_floor)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf1, -4, 4, round_floor)[:3] == (1, 1, 4)
|
|
||||||
assert from_man_exp(-0xff, -4, 4, round_floor)[:3] == (1, 1, 4)
|
|
||||||
|
|
||||||
def test_round_ceiling():
|
|
||||||
assert from_man_exp(0, -4, 4, round_ceiling)[:3] == (0, 0, 0)
|
|
||||||
assert from_man_exp(0xf0, -4, 4, round_ceiling)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf1, -4, 4, round_ceiling)[:3] == (0, 1, 4)
|
|
||||||
assert from_man_exp(0xff, -4, 4, round_ceiling)[:3] == (0, 1, 4)
|
|
||||||
assert from_man_exp(-0xf0, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf1, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xff, -4, 4, round_ceiling)[:3] == (1, 15, 0)
|
|
||||||
|
|
||||||
def test_round_nearest():
|
|
||||||
assert from_man_exp(0, -4, 4, round_nearest)[:3] == (0, 0, 0)
|
|
||||||
assert from_man_exp(0xf0, -4, 4, round_nearest)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf7, -4, 4, round_nearest)[:3] == (0, 15, 0)
|
|
||||||
assert from_man_exp(0xf8, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1000 -> 10000.0
|
|
||||||
assert from_man_exp(0xf9, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1001 -> 10000.0
|
|
||||||
assert from_man_exp(0xe8, -4, 4, round_nearest)[:3] == (0, 7, 1) # 1110.1000 -> 1110.0
|
|
||||||
assert from_man_exp(0xe9, -4, 4, round_nearest)[:3] == (0, 15, 0) # 1110.1001 -> 1111.0
|
|
||||||
assert from_man_exp(-0xf0, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf7, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
|
||||||
assert from_man_exp(-0xf8, -4, 4, round_nearest)[:3] == (1, 1, 4)
|
|
||||||
assert from_man_exp(-0xf9, -4, 4, round_nearest)[:3] == (1, 1, 4)
|
|
||||||
assert from_man_exp(-0xe8, -4, 4, round_nearest)[:3] == (1, 7, 1)
|
|
||||||
assert from_man_exp(-0xe9, -4, 4, round_nearest)[:3] == (1, 15, 0)
|
|
||||||
|
|
||||||
def test_rounding_bugs():
|
|
||||||
# 1 less than power-of-two cases
|
|
||||||
assert from_man_exp(72057594037927935, -56, 53, round_up) == (0, 1, 0, 1)
|
|
||||||
assert from_man_exp(73786976294838205979l, -65, 53, round_nearest) == (0, 1, 1, 1)
|
|
||||||
assert from_man_exp(31, 0, 4, round_up) == (0, 1, 5, 1)
|
|
||||||
assert from_man_exp(-31, 0, 4, round_floor) == (1, 1, 5, 1)
|
|
||||||
assert from_man_exp(255, 0, 7, round_up) == (0, 1, 8, 1)
|
|
||||||
assert from_man_exp(-255, 0, 7, round_floor) == (1, 1, 8, 1)
|
|
||||||
|
|
||||||
def test_rounding_issue160():
|
|
||||||
a = from_man_exp(9867,-100)
|
|
||||||
b = from_man_exp(9867,-200)
|
|
||||||
c = from_man_exp(-1,0)
|
|
||||||
z = (1, 1023, -10, 10)
|
|
||||||
assert mpf_add(a, c, 10, 'd') == z
|
|
||||||
assert mpf_add(b, c, 10, 'd') == z
|
|
||||||
assert mpf_add(c, a, 10, 'd') == z
|
|
||||||
assert mpf_add(c, b, 10, 'd') == z
|
|
||||||
|
|
||||||
def test_perturb():
|
|
||||||
a = fone
|
|
||||||
b = from_float(0.99999999999999989)
|
|
||||||
c = from_float(1.0000000000000002)
|
|
||||||
assert mpf_perturb(a, 0, 53, round_nearest) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_nearest) == a
|
|
||||||
assert mpf_perturb(a, 0, 53, round_up) == c
|
|
||||||
assert mpf_perturb(a, 0, 53, round_ceiling) == c
|
|
||||||
assert mpf_perturb(a, 0, 53, round_down) == a
|
|
||||||
assert mpf_perturb(a, 0, 53, round_floor) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_up) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_ceiling) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_down) == b
|
|
||||||
assert mpf_perturb(a, 1, 53, round_floor) == b
|
|
||||||
a = mpf_neg(a)
|
|
||||||
b = mpf_neg(b)
|
|
||||||
c = mpf_neg(c)
|
|
||||||
assert mpf_perturb(a, 0, 53, round_nearest) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_nearest) == a
|
|
||||||
assert mpf_perturb(a, 0, 53, round_up) == a
|
|
||||||
assert mpf_perturb(a, 0, 53, round_floor) == a
|
|
||||||
assert mpf_perturb(a, 0, 53, round_down) == b
|
|
||||||
assert mpf_perturb(a, 0, 53, round_ceiling) == b
|
|
||||||
assert mpf_perturb(a, 1, 53, round_up) == c
|
|
||||||
assert mpf_perturb(a, 1, 53, round_floor) == c
|
|
||||||
assert mpf_perturb(a, 1, 53, round_down) == a
|
|
||||||
assert mpf_perturb(a, 1, 53, round_ceiling) == a
|
|
||||||
|
|
||||||
def test_add_exact():
|
|
||||||
ff = from_float
|
|
||||||
assert mpf_add(ff(3.0), ff(2.5)) == ff(5.5)
|
|
||||||
assert mpf_add(ff(3.0), ff(-2.5)) == ff(0.5)
|
|
||||||
assert mpf_add(ff(-3.0), ff(2.5)) == ff(-0.5)
|
|
||||||
assert mpf_add(ff(-3.0), ff(-2.5)) == ff(-5.5)
|
|
||||||
assert mpf_sub(mpf_add(fone, ff(1e-100)), fone) == ff(1e-100)
|
|
||||||
assert mpf_sub(mpf_add(ff(1e-100), fone), fone) == ff(1e-100)
|
|
||||||
assert mpf_sub(mpf_add(fone, ff(-1e-100)), fone) == ff(-1e-100)
|
|
||||||
assert mpf_sub(mpf_add(ff(-1e-100), fone), fone) == ff(-1e-100)
|
|
||||||
assert mpf_add(fone, fzero) == fone
|
|
||||||
assert mpf_add(fzero, fone) == fone
|
|
||||||
assert mpf_add(fzero, fzero) == fzero
|
|
||||||
|
|
||||||
def test_long_exponent_shifts():
|
|
||||||
mp.dps = 15
|
|
||||||
# Check for possible bugs due to exponent arithmetic overflow
|
|
||||||
# in a C implementation
|
|
||||||
x = mpf(1)
|
|
||||||
for p in [32, 64]:
|
|
||||||
a = ldexp(1,2**(p-1))
|
|
||||||
b = ldexp(1,2**p)
|
|
||||||
c = ldexp(1,2**(p+1))
|
|
||||||
d = ldexp(1,-2**(p-1))
|
|
||||||
e = ldexp(1,-2**p)
|
|
||||||
f = ldexp(1,-2**(p+1))
|
|
||||||
assert (x+a) == a
|
|
||||||
assert (x+b) == b
|
|
||||||
assert (x+c) == c
|
|
||||||
assert (x+d) == x
|
|
||||||
assert (x+e) == x
|
|
||||||
assert (x+f) == x
|
|
||||||
assert (a+x) == a
|
|
||||||
assert (b+x) == b
|
|
||||||
assert (c+x) == c
|
|
||||||
assert (d+x) == x
|
|
||||||
assert (e+x) == x
|
|
||||||
assert (f+x) == x
|
|
||||||
assert (x-a) == -a
|
|
||||||
assert (x-b) == -b
|
|
||||||
assert (x-c) == -c
|
|
||||||
assert (x-d) == x
|
|
||||||
assert (x-e) == x
|
|
||||||
assert (x-f) == x
|
|
||||||
assert (a-x) == a
|
|
||||||
assert (b-x) == b
|
|
||||||
assert (c-x) == c
|
|
||||||
assert (d-x) == -x
|
|
||||||
assert (e-x) == -x
|
|
||||||
assert (f-x) == -x
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_approximation():
|
|
||||||
mp.dps = 15
|
|
||||||
f = lambda x: cos(2-2*x)/x
|
|
||||||
p, err = chebyfit(f, [2, 4], 8, error=True)
|
|
||||||
assert err < 1e-5
|
|
||||||
for i in range(10):
|
|
||||||
x = 2 + i/5.
|
|
||||||
assert abs(polyval(p, x) - f(x)) < err
|
|
||||||
|
|
||||||
def test_limits():
|
|
||||||
mp.dps = 15
|
|
||||||
assert limit(lambda x: (x-sin(x))/x**3, 0).ae(mpf(1)/6)
|
|
||||||
assert limit(lambda n: (1+1/n)**n, inf).ae(e)
|
|
||||||
|
|
||||||
def test_polyval():
|
|
||||||
assert polyval([], 3) == 0
|
|
||||||
assert polyval([0], 3) == 0
|
|
||||||
assert polyval([5], 3) == 5
|
|
||||||
# 4x^3 - 2x + 5
|
|
||||||
p = [4, 0, -2, 5]
|
|
||||||
assert polyval(p,4) == 253
|
|
||||||
assert polyval(p,4,derivative=True) == (253, 190)
|
|
||||||
|
|
||||||
def test_polyroots():
|
|
||||||
p = polyroots([1,-4])
|
|
||||||
assert p[0].ae(4)
|
|
||||||
p, q = polyroots([1,2,3])
|
|
||||||
assert p.ae(-1 - sqrt(2)*j)
|
|
||||||
assert q.ae(-1 + sqrt(2)*j)
|
|
||||||
#this is not a real test, it only tests a specific case
|
|
||||||
assert polyroots([1]) == []
|
|
||||||
try:
|
|
||||||
polyroots([0])
|
|
||||||
assert False
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_pade():
|
|
||||||
one = mpf(1)
|
|
||||||
mp.dps = 20
|
|
||||||
N = 10
|
|
||||||
a = [one]
|
|
||||||
k = 1
|
|
||||||
for i in range(1, N+1):
|
|
||||||
k *= i
|
|
||||||
a.append(one/k)
|
|
||||||
p, q = pade(a, N//2, N//2)
|
|
||||||
for x in arange(0, 1, 0.1):
|
|
||||||
r = polyval(p[::-1], x)/polyval(q[::-1], x)
|
|
||||||
assert(r.ae(exp(x), 1.0e-10))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_fourier():
|
|
||||||
mp.dps = 15
|
|
||||||
c, s = fourier(lambda x: x+1, [-1, 2], 2)
|
|
||||||
#plot([lambda x: x+1, lambda x: fourierval((c, s), [-1, 2], x)], [-1, 2])
|
|
||||||
assert c[0].ae(1.5)
|
|
||||||
assert c[1].ae(-3*sqrt(3)/(2*pi))
|
|
||||||
assert c[2].ae(3*sqrt(3)/(4*pi))
|
|
||||||
assert s[0] == 0
|
|
||||||
assert s[1].ae(3/(2*pi))
|
|
||||||
assert s[2].ae(3/(4*pi))
|
|
||||||
assert fourierval((c, s), [-1, 2], 1).ae(1.9134966715663442)
|
|
||||||
|
|
||||||
def test_differint():
|
|
||||||
mp.dps = 15
|
|
||||||
assert differint(lambda t: t, 2, -0.5).ae(8*sqrt(2/pi)/3)
|
|
||||||
|
|
@ -1,77 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
from random import seed, randint, random
|
|
||||||
import math
|
|
||||||
|
|
||||||
# Test compatibility with Python floats, which are
|
|
||||||
# IEEE doubles (53-bit)
|
|
||||||
|
|
||||||
N = 5000
|
|
||||||
seed(1)
|
|
||||||
|
|
||||||
# Choosing exponents between roughly -140, 140 ensures that
|
|
||||||
# the Python floats don't overflow or underflow
|
|
||||||
xs = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
|
|
||||||
ys = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
|
|
||||||
|
|
||||||
# include some equal values
|
|
||||||
ys[int(N*0.8):] = xs[int(N*0.8):]
|
|
||||||
|
|
||||||
# Detect whether Python is compiled to use 80-bit floating-point
|
|
||||||
# instructions, in which case the double compatibility test breaks
|
|
||||||
uses_x87 = -4.1974624032366689e+117 / -8.4657370748010221e-47 \
|
|
||||||
== 4.9581771393902231e+163
|
|
||||||
|
|
||||||
def test_double_compatibility():
|
|
||||||
mp.prec = 53
|
|
||||||
for x, y in zip(xs, ys):
|
|
||||||
mpx = mpf(x)
|
|
||||||
mpy = mpf(y)
|
|
||||||
assert mpf(x) == x
|
|
||||||
assert (mpx < mpy) == (x < y)
|
|
||||||
assert (mpx > mpy) == (x > y)
|
|
||||||
assert (mpx == mpy) == (x == y)
|
|
||||||
assert (mpx != mpy) == (x != y)
|
|
||||||
assert (mpx <= mpy) == (x <= y)
|
|
||||||
assert (mpx >= mpy) == (x >= y)
|
|
||||||
assert mpx == mpx
|
|
||||||
if uses_x87:
|
|
||||||
mp.prec = 64
|
|
||||||
a = mpx + mpy
|
|
||||||
b = mpx * mpy
|
|
||||||
c = mpx / mpy
|
|
||||||
d = mpx % mpy
|
|
||||||
mp.prec = 53
|
|
||||||
assert +a == x + y
|
|
||||||
assert +b == x * y
|
|
||||||
assert +c == x / y
|
|
||||||
assert +d == x % y
|
|
||||||
else:
|
|
||||||
assert mpx + mpy == x + y
|
|
||||||
assert mpx * mpy == x * y
|
|
||||||
assert mpx / mpy == x / y
|
|
||||||
assert mpx % mpy == x % y
|
|
||||||
assert abs(mpx) == abs(x)
|
|
||||||
assert mpf(repr(x)) == x
|
|
||||||
assert ceil(mpx) == math.ceil(x)
|
|
||||||
assert floor(mpx) == math.floor(x)
|
|
||||||
|
|
||||||
def test_sqrt():
|
|
||||||
# this fails quite often. it appers to be float
|
|
||||||
# that rounds the wrong way, not mpf
|
|
||||||
fail = 0
|
|
||||||
mp.prec = 53
|
|
||||||
for x in xs:
|
|
||||||
x = abs(x)
|
|
||||||
mp.prec = 100
|
|
||||||
mp_high = mpf(x)**0.5
|
|
||||||
mp.prec = 53
|
|
||||||
mp_low = mpf(x)**0.5
|
|
||||||
fp = x**0.5
|
|
||||||
assert abs(mp_low-mp_high) <= abs(fp-mp_high)
|
|
||||||
fail += mp_low != fp
|
|
||||||
assert fail < N/10
|
|
||||||
|
|
||||||
def test_bugs():
|
|
||||||
# particular bugs
|
|
||||||
assert mpf(4.4408920985006262E-16) < mpf(1.7763568394002505E-15)
|
|
||||||
assert mpf(-4.4408920985006262E-16) > mpf(-1.7763568394002505E-15)
|
|
||||||
|
|
@ -1,186 +0,0 @@
|
||||||
import random
|
|
||||||
from mpmath import *
|
|
||||||
from mpmath.libmp import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_string():
|
|
||||||
"""
|
|
||||||
Test basic string conversion
|
|
||||||
"""
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpf('3') == mpf('3.0') == mpf('0003.') == mpf('0.03e2') == mpf(3.0)
|
|
||||||
assert mpf('30') == mpf('30.0') == mpf('00030.') == mpf(30.0)
|
|
||||||
for i in range(10):
|
|
||||||
for j in range(10):
|
|
||||||
assert mpf('%ie%i' % (i,j)) == i * 10**j
|
|
||||||
assert str(mpf('25000.0')) == '25000.0'
|
|
||||||
assert str(mpf('2500.0')) == '2500.0'
|
|
||||||
assert str(mpf('250.0')) == '250.0'
|
|
||||||
assert str(mpf('25.0')) == '25.0'
|
|
||||||
assert str(mpf('2.5')) == '2.5'
|
|
||||||
assert str(mpf('0.25')) == '0.25'
|
|
||||||
assert str(mpf('0.025')) == '0.025'
|
|
||||||
assert str(mpf('0.0025')) == '0.0025'
|
|
||||||
assert str(mpf('0.00025')) == '0.00025'
|
|
||||||
assert str(mpf('0.000025')) == '2.5e-5'
|
|
||||||
assert str(mpf(0)) == '0.0'
|
|
||||||
assert str(mpf('2.5e1000000000000000000000')) == '2.5e+1000000000000000000000'
|
|
||||||
assert str(mpf('2.6e-1000000000000000000000')) == '2.6e-1000000000000000000000'
|
|
||||||
assert str(mpf(1.23402834e-15)) == '1.23402834e-15'
|
|
||||||
assert str(mpf(-1.23402834e-15)) == '-1.23402834e-15'
|
|
||||||
assert str(mpf(-1.2344e-15)) == '-1.2344e-15'
|
|
||||||
assert repr(mpf(-1.2344e-15)) == "mpf('-1.2343999999999999e-15')"
|
|
||||||
|
|
||||||
def test_pretty():
|
|
||||||
mp.pretty = True
|
|
||||||
assert repr(mpf(2.5)) == '2.5'
|
|
||||||
assert repr(mpc(2.5,3.5)) == '(2.5 + 3.5j)'
|
|
||||||
assert repr(mpi(2.5,3.5)) == '[2.5, 3.5]'
|
|
||||||
mp.pretty = False
|
|
||||||
|
|
||||||
def test_str_whitespace():
|
|
||||||
assert mpf('1.26 ') == 1.26
|
|
||||||
|
|
||||||
def test_unicode():
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpf(u'2.76') == 2.76
|
|
||||||
assert mpf(u'inf') == inf
|
|
||||||
|
|
||||||
def test_str_format():
|
|
||||||
assert to_str(from_float(0.1),15,strip_zeros=False) == '0.100000000000000'
|
|
||||||
assert to_str(from_float(0.0),15,show_zero_exponent=True) == '0.0e+0'
|
|
||||||
assert to_str(from_float(0.0),0,show_zero_exponent=True) == '.0e+0'
|
|
||||||
assert to_str(from_float(0.0),0,show_zero_exponent=False) == '.0'
|
|
||||||
assert to_str(from_float(0.0),1,show_zero_exponent=True) == '0.0e+0'
|
|
||||||
assert to_str(from_float(0.0),1,show_zero_exponent=False) == '0.0'
|
|
||||||
assert to_str(from_float(1.23),3,show_zero_exponent=True) == '1.23e+0'
|
|
||||||
assert to_str(from_float(1.23456789000000e-2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e-2'
|
|
||||||
assert to_str(from_float(1.23456789000000e+2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e+2'
|
|
||||||
assert to_str(from_float(2.1287e14), 15, max_fixed=1000) == '212870000000000.0'
|
|
||||||
assert to_str(from_float(2.1287e15), 15, max_fixed=1000) == '2128700000000000.0'
|
|
||||||
assert to_str(from_float(2.1287e16), 15, max_fixed=1000) == '21287000000000000.0'
|
|
||||||
assert to_str(from_float(2.1287e30), 15, max_fixed=1000) == '2128700000000000000000000000000.0'
|
|
||||||
|
|
||||||
def test_tight_string_conversion():
|
|
||||||
mp.dps = 15
|
|
||||||
# In an old version, '0.5' wasn't recognized as representing
|
|
||||||
# an exact binary number and was erroneously rounded up or down
|
|
||||||
assert from_str('0.5', 10, round_floor) == fhalf
|
|
||||||
assert from_str('0.5', 10, round_ceiling) == fhalf
|
|
||||||
|
|
||||||
def test_eval_repr_invariant():
|
|
||||||
"""Test that eval(repr(x)) == x"""
|
|
||||||
random.seed(123)
|
|
||||||
for dps in [10, 15, 20, 50, 100]:
|
|
||||||
mp.dps = dps
|
|
||||||
for i in xrange(1000):
|
|
||||||
a = mpf(random.random())**0.5 * 10**random.randint(-100, 100)
|
|
||||||
assert eval(repr(a)) == a
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_str_bugs():
|
|
||||||
mp.dps = 15
|
|
||||||
# Decimal rounding used to give the wrong exponent in some cases
|
|
||||||
assert str(mpf('1e600')) == '1.0e+600'
|
|
||||||
assert str(mpf('1e10000')) == '1.0e+10000'
|
|
||||||
|
|
||||||
def test_str_prec0():
|
|
||||||
assert to_str(from_float(1.234), 0) == '.0e+0'
|
|
||||||
assert to_str(from_float(1e-15), 0) == '.0e-15'
|
|
||||||
assert to_str(from_float(1e+15), 0) == '.0e+15'
|
|
||||||
assert to_str(from_float(-1e-15), 0) == '-.0e-15'
|
|
||||||
assert to_str(from_float(-1e+15), 0) == '-.0e+15'
|
|
||||||
|
|
||||||
def test_convert_rational():
|
|
||||||
mp.dps = 15
|
|
||||||
assert from_rational(30, 5, 53, round_nearest) == (0, 3, 1, 2)
|
|
||||||
assert from_rational(-7, 4, 53, round_nearest) == (1, 7, -2, 3)
|
|
||||||
assert to_rational((0, 1, -1, 1)) == (1, 2)
|
|
||||||
|
|
||||||
def test_custom_class():
|
|
||||||
class mympf:
|
|
||||||
@property
|
|
||||||
def _mpf_(self):
|
|
||||||
return mpf(3.5)._mpf_
|
|
||||||
class mympc:
|
|
||||||
@property
|
|
||||||
def _mpc_(self):
|
|
||||||
return mpf(3.5)._mpf_, mpf(2.5)._mpf_
|
|
||||||
assert mpf(2) + mympf() == 5.5
|
|
||||||
assert mympf() + mpf(2) == 5.5
|
|
||||||
assert mpf(mympf()) == 3.5
|
|
||||||
assert mympc() + mpc(2) == mpc(5.5, 2.5)
|
|
||||||
assert mpc(2) + mympc() == mpc(5.5, 2.5)
|
|
||||||
assert mpc(mympc()) == (3.5+2.5j)
|
|
||||||
|
|
||||||
def test_conversion_methods():
|
|
||||||
class SomethingRandom:
|
|
||||||
pass
|
|
||||||
class SomethingReal:
|
|
||||||
def _mpmath_(self, prec, rounding):
|
|
||||||
return mp.make_mpf(from_str('1.3', prec, rounding))
|
|
||||||
class SomethingComplex:
|
|
||||||
def _mpmath_(self, prec, rounding):
|
|
||||||
return mp.make_mpc((from_str('1.3', prec, rounding), \
|
|
||||||
from_str('1.7', prec, rounding)))
|
|
||||||
x = mpf(3)
|
|
||||||
z = mpc(3)
|
|
||||||
a = SomethingRandom()
|
|
||||||
y = SomethingReal()
|
|
||||||
w = SomethingComplex()
|
|
||||||
for d in [15, 45]:
|
|
||||||
mp.dps = d
|
|
||||||
assert (x+y).ae(mpf('4.3'))
|
|
||||||
assert (y+x).ae(mpf('4.3'))
|
|
||||||
assert (x+w).ae(mpc('4.3', '1.7'))
|
|
||||||
assert (w+x).ae(mpc('4.3', '1.7'))
|
|
||||||
assert (z+y).ae(mpc('4.3'))
|
|
||||||
assert (y+z).ae(mpc('4.3'))
|
|
||||||
assert (z+w).ae(mpc('4.3', '1.7'))
|
|
||||||
assert (w+z).ae(mpc('4.3', '1.7'))
|
|
||||||
x-y; y-x; x-w; w-x; z-y; y-z; z-w; w-z
|
|
||||||
x*y; y*x; x*w; w*x; z*y; y*z; z*w; w*z
|
|
||||||
x/y; y/x; x/w; w/x; z/y; y/z; z/w; w/z
|
|
||||||
x**y; y**x; x**w; w**x; z**y; y**z; z**w; w**z
|
|
||||||
x==y; y==x; x==w; w==x; z==y; y==z; z==w; w==z
|
|
||||||
mp.dps = 15
|
|
||||||
assert x.__add__(a) is NotImplemented
|
|
||||||
assert x.__radd__(a) is NotImplemented
|
|
||||||
assert x.__lt__(a) is NotImplemented
|
|
||||||
assert x.__gt__(a) is NotImplemented
|
|
||||||
assert x.__le__(a) is NotImplemented
|
|
||||||
assert x.__ge__(a) is NotImplemented
|
|
||||||
assert x.__eq__(a) is NotImplemented
|
|
||||||
assert x.__ne__(a) is NotImplemented
|
|
||||||
# implementation detail
|
|
||||||
if hasattr(x, "__cmp__"):
|
|
||||||
assert x.__cmp__(a) is NotImplemented
|
|
||||||
assert x.__sub__(a) is NotImplemented
|
|
||||||
assert x.__rsub__(a) is NotImplemented
|
|
||||||
assert x.__mul__(a) is NotImplemented
|
|
||||||
assert x.__rmul__(a) is NotImplemented
|
|
||||||
assert x.__div__(a) is NotImplemented
|
|
||||||
assert x.__rdiv__(a) is NotImplemented
|
|
||||||
assert x.__mod__(a) is NotImplemented
|
|
||||||
assert x.__rmod__(a) is NotImplemented
|
|
||||||
assert x.__pow__(a) is NotImplemented
|
|
||||||
assert x.__rpow__(a) is NotImplemented
|
|
||||||
assert z.__add__(a) is NotImplemented
|
|
||||||
assert z.__radd__(a) is NotImplemented
|
|
||||||
assert z.__eq__(a) is NotImplemented
|
|
||||||
assert z.__ne__(a) is NotImplemented
|
|
||||||
assert z.__sub__(a) is NotImplemented
|
|
||||||
assert z.__rsub__(a) is NotImplemented
|
|
||||||
assert z.__mul__(a) is NotImplemented
|
|
||||||
assert z.__rmul__(a) is NotImplemented
|
|
||||||
assert z.__div__(a) is NotImplemented
|
|
||||||
assert z.__rdiv__(a) is NotImplemented
|
|
||||||
assert z.__pow__(a) is NotImplemented
|
|
||||||
assert z.__rpow__(a) is NotImplemented
|
|
||||||
|
|
||||||
def test_mpmathify():
|
|
||||||
assert mpmathify('1/2') == 0.5
|
|
||||||
assert mpmathify('(1.0+1.0j)') == mpc(1, 1)
|
|
||||||
assert mpmathify('(1.2e-10 - 3.4e5j)') == mpc('1.2e-10', '-3.4e5')
|
|
||||||
assert mpmathify('1j') == mpc(1j)
|
|
||||||
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_diff():
|
|
||||||
assert diff(log, 2.0, n=0).ae(log(2))
|
|
||||||
assert diff(cos, 1.0).ae(-sin(1))
|
|
||||||
assert diff(abs, 0.0) == 0
|
|
||||||
assert diff(abs, 0.0, direction=1) == 1
|
|
||||||
assert diff(abs, 0.0, direction=-1) == -1
|
|
||||||
assert diff(exp, 1.0).ae(e)
|
|
||||||
assert diff(exp, 1.0, n=5).ae(e)
|
|
||||||
assert diff(exp, 2.0, n=5, direction=3*j).ae(e**2)
|
|
||||||
assert diff(lambda x: x**2, 3.0, method='quad').ae(6)
|
|
||||||
assert diff(lambda x: 3+x**5, 3.0, n=2, method='quad').ae(540)
|
|
||||||
assert diff(lambda x: 3+x**5, 3.0, n=2, method='step').ae(540)
|
|
||||||
assert diffun(sin)(2).ae(cos(2))
|
|
||||||
assert diffun(sin, n=2)(2).ae(-sin(2))
|
|
||||||
|
|
||||||
def test_taylor():
|
|
||||||
# Easy to test since the coefficients are exact in floating-point
|
|
||||||
assert taylor(sqrt, 1, 4) == [1, 0.5, -0.125, 0.0625, -0.0390625]
|
|
||||||
|
|
@ -1,143 +0,0 @@
|
||||||
from mpmath.libmp import *
|
|
||||||
from mpmath import mpf, mp
|
|
||||||
|
|
||||||
from random import randint, choice, seed
|
|
||||||
|
|
||||||
all_modes = [round_floor, round_ceiling, round_down, round_up, round_nearest]
|
|
||||||
|
|
||||||
fb = from_bstr
|
|
||||||
fi = from_int
|
|
||||||
ff = from_float
|
|
||||||
|
|
||||||
|
|
||||||
def test_div_1_3():
|
|
||||||
a = fi(1)
|
|
||||||
b = fi(3)
|
|
||||||
c = fi(-1)
|
|
||||||
|
|
||||||
# floor rounds down, ceiling rounds up
|
|
||||||
assert mpf_div(a, b, 7, round_floor) == fb('0.01010101')
|
|
||||||
assert mpf_div(a, b, 7, round_ceiling) == fb('0.01010110')
|
|
||||||
assert mpf_div(a, b, 7, round_down) == fb('0.01010101')
|
|
||||||
assert mpf_div(a, b, 7, round_up) == fb('0.01010110')
|
|
||||||
assert mpf_div(a, b, 7, round_nearest) == fb('0.01010101')
|
|
||||||
|
|
||||||
# floor rounds up, ceiling rounds down
|
|
||||||
assert mpf_div(c, b, 7, round_floor) == fb('-0.01010110')
|
|
||||||
assert mpf_div(c, b, 7, round_ceiling) == fb('-0.01010101')
|
|
||||||
assert mpf_div(c, b, 7, round_down) == fb('-0.01010101')
|
|
||||||
assert mpf_div(c, b, 7, round_up) == fb('-0.01010110')
|
|
||||||
assert mpf_div(c, b, 7, round_nearest) == fb('-0.01010101')
|
|
||||||
|
|
||||||
def test_mpf_divi_1_3():
|
|
||||||
a = 1
|
|
||||||
b = fi(3)
|
|
||||||
c = -1
|
|
||||||
assert mpf_rdiv_int(a, b, 7, round_floor) == fb('0.01010101')
|
|
||||||
assert mpf_rdiv_int(a, b, 7, round_ceiling) == fb('0.01010110')
|
|
||||||
assert mpf_rdiv_int(a, b, 7, round_down) == fb('0.01010101')
|
|
||||||
assert mpf_rdiv_int(a, b, 7, round_up) == fb('0.01010110')
|
|
||||||
assert mpf_rdiv_int(a, b, 7, round_nearest) == fb('0.01010101')
|
|
||||||
assert mpf_rdiv_int(c, b, 7, round_floor) == fb('-0.01010110')
|
|
||||||
assert mpf_rdiv_int(c, b, 7, round_ceiling) == fb('-0.01010101')
|
|
||||||
assert mpf_rdiv_int(c, b, 7, round_down) == fb('-0.01010101')
|
|
||||||
assert mpf_rdiv_int(c, b, 7, round_up) == fb('-0.01010110')
|
|
||||||
assert mpf_rdiv_int(c, b, 7, round_nearest) == fb('-0.01010101')
|
|
||||||
|
|
||||||
|
|
||||||
def test_div_300():
|
|
||||||
|
|
||||||
q = fi(1000000)
|
|
||||||
a = fi(300499999) # a/q is a little less than a half-integer
|
|
||||||
b = fi(300500000) # b/q exactly a half-integer
|
|
||||||
c = fi(300500001) # c/q is a little more than a half-integer
|
|
||||||
|
|
||||||
# Check nearest integer rounding (prec=9 as 2**8 < 300 < 2**9)
|
|
||||||
|
|
||||||
assert mpf_div(a, q, 9, round_down) == fi(300)
|
|
||||||
assert mpf_div(b, q, 9, round_down) == fi(300)
|
|
||||||
assert mpf_div(c, q, 9, round_down) == fi(300)
|
|
||||||
assert mpf_div(a, q, 9, round_up) == fi(301)
|
|
||||||
assert mpf_div(b, q, 9, round_up) == fi(301)
|
|
||||||
assert mpf_div(c, q, 9, round_up) == fi(301)
|
|
||||||
|
|
||||||
# Nearest even integer is down
|
|
||||||
assert mpf_div(a, q, 9, round_nearest) == fi(300)
|
|
||||||
assert mpf_div(b, q, 9, round_nearest) == fi(300)
|
|
||||||
assert mpf_div(c, q, 9, round_nearest) == fi(301)
|
|
||||||
|
|
||||||
# Nearest even integer is up
|
|
||||||
a = fi(301499999)
|
|
||||||
b = fi(301500000)
|
|
||||||
c = fi(301500001)
|
|
||||||
assert mpf_div(a, q, 9, round_nearest) == fi(301)
|
|
||||||
assert mpf_div(b, q, 9, round_nearest) == fi(302)
|
|
||||||
assert mpf_div(c, q, 9, round_nearest) == fi(302)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tight_integer_division():
|
|
||||||
# Test that integer division at tightest possible precision is exact
|
|
||||||
N = 100
|
|
||||||
seed(1)
|
|
||||||
for i in range(N):
|
|
||||||
a = choice([1, -1]) * randint(1, 1<<randint(10, 100))
|
|
||||||
b = choice([1, -1]) * randint(1, 1<<randint(10, 100))
|
|
||||||
p = a * b
|
|
||||||
width = bitcount(abs(b)) - trailing(b)
|
|
||||||
a = fi(a); b = fi(b); p = fi(p)
|
|
||||||
for mode in all_modes:
|
|
||||||
assert mpf_div(p, a, width, mode) == b
|
|
||||||
|
|
||||||
|
|
||||||
def test_epsilon_rounding():
|
|
||||||
# Verify that mpf_div uses infinite precision; this result will
|
|
||||||
# appear to be exactly 0.101 to a near-sighted algorithm
|
|
||||||
|
|
||||||
a = fb('0.101' + ('0'*200) + '1')
|
|
||||||
b = fb('1.10101')
|
|
||||||
c = mpf_mul(a, b, 250, round_floor) # exact
|
|
||||||
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a # exact
|
|
||||||
|
|
||||||
assert mpf_div(c, b, 2, round_down) == fb('0.10')
|
|
||||||
assert mpf_div(c, b, 3, round_down) == fb('0.101')
|
|
||||||
assert mpf_div(c, b, 2, round_up) == fb('0.11')
|
|
||||||
assert mpf_div(c, b, 3, round_up) == fb('0.110')
|
|
||||||
assert mpf_div(c, b, 2, round_floor) == fb('0.10')
|
|
||||||
assert mpf_div(c, b, 3, round_floor) == fb('0.101')
|
|
||||||
assert mpf_div(c, b, 2, round_ceiling) == fb('0.11')
|
|
||||||
assert mpf_div(c, b, 3, round_ceiling) == fb('0.110')
|
|
||||||
|
|
||||||
# The same for negative numbers
|
|
||||||
a = fb('-0.101' + ('0'*200) + '1')
|
|
||||||
b = fb('1.10101')
|
|
||||||
c = mpf_mul(a, b, 250, round_floor)
|
|
||||||
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a
|
|
||||||
|
|
||||||
assert mpf_div(c, b, 2, round_down) == fb('-0.10')
|
|
||||||
assert mpf_div(c, b, 3, round_up) == fb('-0.110')
|
|
||||||
|
|
||||||
# Floor goes up, ceiling goes down
|
|
||||||
assert mpf_div(c, b, 2, round_floor) == fb('-0.11')
|
|
||||||
assert mpf_div(c, b, 3, round_floor) == fb('-0.110')
|
|
||||||
assert mpf_div(c, b, 2, round_ceiling) == fb('-0.10')
|
|
||||||
assert mpf_div(c, b, 3, round_ceiling) == fb('-0.101')
|
|
||||||
|
|
||||||
|
|
||||||
def test_mod():
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpf(234) % 1 == 0
|
|
||||||
assert mpf(-3) % 256 == 253
|
|
||||||
assert mpf(0.25) % 23490.5 == 0.25
|
|
||||||
assert mpf(0.25) % -23490.5 == -23490.25
|
|
||||||
assert mpf(-0.25) % 23490.5 == 23490.25
|
|
||||||
assert mpf(-0.25) % -23490.5 == -0.25
|
|
||||||
# Check that these cases are handled efficiently
|
|
||||||
assert mpf('1e10000000000') % 1 == 0
|
|
||||||
assert mpf('1.23e-1000000000') % 1 == mpf('1.23e-1000000000')
|
|
||||||
# test __rmod__
|
|
||||||
assert 3 % mpf('1.75') == 1.25
|
|
||||||
|
|
||||||
def test_div_negative_rnd_bug():
|
|
||||||
mp.dps = 15
|
|
||||||
assert (-3) / mpf('0.1531879017645047') == mpf('-19.583791966887116')
|
|
||||||
assert mpf('-2.6342475750861301') / mpf('0.35126216427941814') == mpf('-7.4993775104985909')
|
|
||||||
|
|
@ -1,537 +0,0 @@
|
||||||
"""
|
|
||||||
Limited tests of the elliptic functions module. A full suite of
|
|
||||||
extensive testing can be found in elliptic_torture_tests.py
|
|
||||||
|
|
||||||
Author of the first version: M.T. Taschuk
|
|
||||||
|
|
||||||
References:
|
|
||||||
|
|
||||||
[1] Abramowitz & Stegun. 'Handbook of Mathematical Functions, 9th Ed.',
|
|
||||||
(Dover duplicate of 1972 edition)
|
|
||||||
[2] Whittaker 'A Course of Modern Analysis, 4th Ed.', 1946,
|
|
||||||
Cambridge University Press
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import mpmath
|
|
||||||
import random
|
|
||||||
|
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def mpc_ae(a, b, eps=eps):
|
|
||||||
res = True
|
|
||||||
res = res and a.real.ae(b.real, eps)
|
|
||||||
res = res and a.imag.ae(b.imag, eps)
|
|
||||||
return res
|
|
||||||
|
|
||||||
zero = mpf(0)
|
|
||||||
one = mpf(1)
|
|
||||||
|
|
||||||
def test_calculate_nome():
|
|
||||||
mp.dps = 100
|
|
||||||
|
|
||||||
q = calculate_nome(zero)
|
|
||||||
assert(q == zero)
|
|
||||||
|
|
||||||
mp.dps = 25
|
|
||||||
# used Mathematica's EllipticNomeQ[m]
|
|
||||||
math1 = [(mpf(1)/10, mpf('0.006584651553858370274473060')),
|
|
||||||
(mpf(2)/10, mpf('0.01394285727531826872146409')),
|
|
||||||
(mpf(3)/10, mpf('0.02227743615715350822901627')),
|
|
||||||
(mpf(4)/10, mpf('0.03188334731336317755064299')),
|
|
||||||
(mpf(5)/10, mpf('0.04321391826377224977441774')),
|
|
||||||
(mpf(6)/10, mpf('0.05702025781460967637754953')),
|
|
||||||
(mpf(7)/10, mpf('0.07468994353717944761143751')),
|
|
||||||
(mpf(8)/10, mpf('0.09927369733882489703607378')),
|
|
||||||
(mpf(9)/10, mpf('0.1401731269542615524091055')),
|
|
||||||
(mpf(9)/10, mpf('0.1401731269542615524091055'))]
|
|
||||||
|
|
||||||
for i in math1:
|
|
||||||
m = i[0]
|
|
||||||
q = calculate_nome(sqrt(m))
|
|
||||||
assert q.ae(i[1])
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_jtheta():
|
|
||||||
mp.dps = 25
|
|
||||||
|
|
||||||
z = q = zero
|
|
||||||
for n in range(1,5):
|
|
||||||
value = jtheta(n, z, q)
|
|
||||||
assert(value == (n-1)//2)
|
|
||||||
|
|
||||||
for q in [one, mpf(2)]:
|
|
||||||
for n in range(1,5):
|
|
||||||
raised = True
|
|
||||||
try:
|
|
||||||
r = jtheta(n, z, q)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raised = False
|
|
||||||
assert(raised)
|
|
||||||
|
|
||||||
z = one/10
|
|
||||||
q = one/11
|
|
||||||
|
|
||||||
# Mathematical N[EllipticTheta[1, 1/10, 1/11], 25]
|
|
||||||
res = mpf('0.1069552990104042681962096')
|
|
||||||
result = jtheta(1, z, q)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[2, 1/10, 1/11], 25]
|
|
||||||
res = mpf('1.101385760258855791140606')
|
|
||||||
result = jtheta(2, z, q)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[3, 1/10, 1/11], 25]
|
|
||||||
res = mpf('1.178319743354331061795905')
|
|
||||||
result = jtheta(3, z, q)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[4, 1/10, 1/11], 25]
|
|
||||||
res = mpf('0.8219318954665153577314573')
|
|
||||||
result = jtheta(4, z, q)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# test for sin zeros for jtheta(1, z, q)
|
|
||||||
# test for cos zeros for jtheta(2, z, q)
|
|
||||||
z1 = pi
|
|
||||||
z2 = pi/2
|
|
||||||
for i in range(10):
|
|
||||||
qstring = str(random.random())
|
|
||||||
q = mpf(qstring)
|
|
||||||
result = jtheta(1, z1, q)
|
|
||||||
assert(result.ae(0))
|
|
||||||
result = jtheta(2, z2, q)
|
|
||||||
assert(result.ae(0))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
|
|
||||||
def test_jtheta_issue39():
|
|
||||||
# near the circle of covergence |q| = 1 the convergence slows
|
|
||||||
# down; for |q| > Q_LIM the theta functions raise ValueError
|
|
||||||
mp.dps = 30
|
|
||||||
mp.dps += 30
|
|
||||||
q = mpf(6)/10 - one/10**6 - mpf(8)/10 * j
|
|
||||||
mp.dps -= 30
|
|
||||||
# Mathematica run first
|
|
||||||
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 2000]
|
|
||||||
# then it works:
|
|
||||||
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 30]
|
|
||||||
res = mpf('32.0031009628901652627099524264') + \
|
|
||||||
mpf('16.6153027998236087899308935624') * j
|
|
||||||
result = jtheta(3, 1, q)
|
|
||||||
# check that for abs(q) > Q_LIM a ValueError exception is raised
|
|
||||||
mp.dps += 30
|
|
||||||
q = mpf(6)/10 - one/10**7 - mpf(8)/10 * j
|
|
||||||
mp.dps -= 30
|
|
||||||
try:
|
|
||||||
result = jtheta(3, 1, q)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
assert(False)
|
|
||||||
|
|
||||||
# bug reported in issue39
|
|
||||||
mp.dps = 100
|
|
||||||
z = (1+j)/3
|
|
||||||
q = mpf(368983957219251)/10**15 + mpf(636363636363636)/10**15 * j
|
|
||||||
# Mathematica N[EllipticTheta[1, z, q], 35]
|
|
||||||
res = mpf('2.4439389177990737589761828991467471') + \
|
|
||||||
mpf('0.5446453005688226915290954851851490') *j
|
|
||||||
mp.dps = 30
|
|
||||||
result = jtheta(1, z, q)
|
|
||||||
assert(result.ae(res))
|
|
||||||
mp.dps = 80
|
|
||||||
z = 3 + 4*j
|
|
||||||
q = 0.5 + 0.5*j
|
|
||||||
r1 = jtheta(1, z, q)
|
|
||||||
mp.dps = 15
|
|
||||||
r2 = jtheta(1, z, q)
|
|
||||||
assert r1.ae(r2)
|
|
||||||
mp.dps = 80
|
|
||||||
z = 3 + j
|
|
||||||
q1 = exp(j*3)
|
|
||||||
# longer test
|
|
||||||
# for n in range(1, 6)
|
|
||||||
for n in range(1, 2):
|
|
||||||
mp.dps = 80
|
|
||||||
q = q1*(1 - mpf(1)/10**n)
|
|
||||||
r1 = jtheta(1, z, q)
|
|
||||||
mp.dps = 15
|
|
||||||
r2 = jtheta(1, z, q)
|
|
||||||
assert r1.ae(r2)
|
|
||||||
mp.dps = 15
|
|
||||||
# issue 39 about high derivatives
|
|
||||||
assert jtheta(3, 4.5, 0.25, 9).ae(1359.04892680683)
|
|
||||||
assert jtheta(3, 4.5, 0.25, 50).ae(-6.14832772630905e+33)
|
|
||||||
mp.dps = 50
|
|
||||||
r = jtheta(3, 4.5, 0.25, 9)
|
|
||||||
assert r.ae('1359.048926806828939547859396600218966947753213803')
|
|
||||||
r = jtheta(3, 4.5, 0.25, 50)
|
|
||||||
assert r.ae('-6148327726309051673317975084654262.4119215720343656')
|
|
||||||
|
|
||||||
def test_jtheta_identities():
|
|
||||||
"""
|
|
||||||
Tests the some of the jacobi identidies found in Abramowitz,
|
|
||||||
Sec. 16.28, Pg. 576. The identities are tested to 1 part in 10^98.
|
|
||||||
"""
|
|
||||||
mp.dps = 110
|
|
||||||
eps1 = ldexp(eps, 30)
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
qstring = str(random.random())
|
|
||||||
q = mpf(qstring)
|
|
||||||
|
|
||||||
zstring = str(10*random.random())
|
|
||||||
z = mpf(zstring)
|
|
||||||
# Abramowitz 16.28.1
|
|
||||||
# v_1(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_2(0, q)**2
|
|
||||||
# - v_2(z, q)**2 * v_3(0, q)**2
|
|
||||||
term1 = (jtheta(1, z, q)**2) * (jtheta(4, zero, q)**2)
|
|
||||||
term2 = (jtheta(3, z, q)**2) * (jtheta(2, zero, q)**2)
|
|
||||||
term3 = (jtheta(2, z, q)**2) * (jtheta(3, zero, q)**2)
|
|
||||||
equality = term1 - term2 + term3
|
|
||||||
assert(equality.ae(0, eps1))
|
|
||||||
|
|
||||||
zstring = str(100*random.random())
|
|
||||||
z = mpf(zstring)
|
|
||||||
# Abramowitz 16.28.2
|
|
||||||
# v_2(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_2(0, q)**2
|
|
||||||
# - v_1(z, q)**2 * v_3(0, q)**2
|
|
||||||
term1 = (jtheta(2, z, q)**2) * (jtheta(4, zero, q)**2)
|
|
||||||
term2 = (jtheta(4, z, q)**2) * (jtheta(2, zero, q)**2)
|
|
||||||
term3 = (jtheta(1, z, q)**2) * (jtheta(3, zero, q)**2)
|
|
||||||
equality = term1 - term2 + term3
|
|
||||||
assert(equality.ae(0, eps1))
|
|
||||||
|
|
||||||
# Abramowitz 16.28.3
|
|
||||||
# v_3(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_3(0, q)**2
|
|
||||||
# - v_1(z, q)**2 * v_2(0, q)**2
|
|
||||||
term1 = (jtheta(3, z, q)**2) * (jtheta(4, zero, q)**2)
|
|
||||||
term2 = (jtheta(4, z, q)**2) * (jtheta(3, zero, q)**2)
|
|
||||||
term3 = (jtheta(1, z, q)**2) * (jtheta(2, zero, q)**2)
|
|
||||||
equality = term1 - term2 + term3
|
|
||||||
assert(equality.ae(0, eps1))
|
|
||||||
|
|
||||||
# Abramowitz 16.28.4
|
|
||||||
# v_4(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_3(0, q)**2
|
|
||||||
# - v_2(z, q)**2 * v_2(0, q)**2
|
|
||||||
term1 = (jtheta(4, z, q)**2) * (jtheta(4, zero, q)**2)
|
|
||||||
term2 = (jtheta(3, z, q)**2) * (jtheta(3, zero, q)**2)
|
|
||||||
term3 = (jtheta(2, z, q)**2) * (jtheta(2, zero, q)**2)
|
|
||||||
equality = term1 - term2 + term3
|
|
||||||
assert(equality.ae(0, eps1))
|
|
||||||
|
|
||||||
# Abramowitz 16.28.5
|
|
||||||
# v_2(0, q)**4 + v_4(0, q)**4 == v_3(0, q)**4
|
|
||||||
term1 = (jtheta(2, zero, q))**4
|
|
||||||
term2 = (jtheta(4, zero, q))**4
|
|
||||||
term3 = (jtheta(3, zero, q))**4
|
|
||||||
equality = term1 + term2 - term3
|
|
||||||
assert(equality.ae(0, eps1))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_jtheta_complex():
|
|
||||||
mp.dps = 30
|
|
||||||
z = mpf(1)/4 + j/8
|
|
||||||
q = mpf(1)/3 + j/7
|
|
||||||
# Mathematica N[EllipticTheta[1, 1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('0.31618034835986160705729105731678285') + \
|
|
||||||
mpf('0.07542013825835103435142515194358975') * j
|
|
||||||
r = jtheta(1, z, q)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[2, 1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('1.6530986428239765928634711417951828') + \
|
|
||||||
mpf('0.2015344864707197230526742145361455') * j
|
|
||||||
r = jtheta(2, z, q)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[3, 1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('1.6520564411784228184326012700348340') + \
|
|
||||||
mpf('0.1998129119671271328684690067401823') * j
|
|
||||||
r = jtheta(3, z, q)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticTheta[4, 1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('0.37619082382228348252047624089973824') - \
|
|
||||||
mpf('0.15623022130983652972686227200681074') * j
|
|
||||||
r = jtheta(4, z, q)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
# check some theta function identities
|
|
||||||
mp.dos = 100
|
|
||||||
z = mpf(1)/4 + j/8
|
|
||||||
q = mpf(1)/3 + j/7
|
|
||||||
mp.dps += 10
|
|
||||||
a = [0,0, jtheta(2, 0, q), jtheta(3, 0, q), jtheta(4, 0, q)]
|
|
||||||
t = [0, jtheta(1, z, q), jtheta(2, z, q), jtheta(3, z, q), jtheta(4, z, q)]
|
|
||||||
r = [(t[2]*a[4])**2 - (t[4]*a[2])**2 + (t[1] *a[3])**2,
|
|
||||||
(t[3]*a[4])**2 - (t[4]*a[3])**2 + (t[1] *a[2])**2,
|
|
||||||
(t[1]*a[4])**2 - (t[3]*a[2])**2 + (t[2] *a[3])**2,
|
|
||||||
(t[4]*a[4])**2 - (t[3]*a[3])**2 + (t[2] *a[2])**2,
|
|
||||||
a[2]**4 + a[4]**4 - a[3]**4]
|
|
||||||
mp.dps -= 10
|
|
||||||
for x in r:
|
|
||||||
assert(mpc_ae(x, mpc(0)))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_djtheta():
|
|
||||||
mp.dps = 30
|
|
||||||
|
|
||||||
z = one/7 + j/3
|
|
||||||
q = one/8 + j/5
|
|
||||||
# Mathematica N[EllipticThetaPrime[1, 1/7 + I/3, 1/8 + I/5], 35]
|
|
||||||
res = mpf('1.5555195883277196036090928995803201') - \
|
|
||||||
mpf('0.02439761276895463494054149673076275') * j
|
|
||||||
result = jtheta(1, z, q, 1)
|
|
||||||
assert(mpc_ae(result, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticThetaPrime[2, 1/7 + I/3, 1/8 + I/5], 35]
|
|
||||||
res = mpf('0.19825296689470982332701283509685662') - \
|
|
||||||
mpf('0.46038135182282106983251742935250009') * j
|
|
||||||
result = jtheta(2, z, q, 1)
|
|
||||||
assert(mpc_ae(result, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticThetaPrime[3, 1/7 + I/3, 1/8 + I/5], 35]
|
|
||||||
res = mpf('0.36492498415476212680896699407390026') - \
|
|
||||||
mpf('0.57743812698666990209897034525640369') * j
|
|
||||||
result = jtheta(3, z, q, 1)
|
|
||||||
assert(mpc_ae(result, res))
|
|
||||||
|
|
||||||
# Mathematica N[EllipticThetaPrime[4, 1/7 + I/3, 1/8 + I/5], 35]
|
|
||||||
res = mpf('-0.38936892528126996010818803742007352') + \
|
|
||||||
mpf('0.66549886179739128256269617407313625') * j
|
|
||||||
result = jtheta(4, z, q, 1)
|
|
||||||
assert(mpc_ae(result, res))
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
q = (one*random.random() + j*random.random())/2
|
|
||||||
# identity in Wittaker, Watson &21.41
|
|
||||||
a = jtheta(1, 0, q, 1)
|
|
||||||
b = jtheta(2, 0, q)*jtheta(3, 0, q)*jtheta(4, 0, q)
|
|
||||||
assert(a.ae(b))
|
|
||||||
|
|
||||||
# test higher derivatives
|
|
||||||
mp.dps = 20
|
|
||||||
for q,z in [(one/3, one/5), (one/3 + j/8, one/5),
|
|
||||||
(one/3, one/5 + j/8), (one/3 + j/7, one/5 + j/8)]:
|
|
||||||
for n in [1, 2, 3, 4]:
|
|
||||||
r = jtheta(n, z, q, 2)
|
|
||||||
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=2)
|
|
||||||
assert r.ae(r1)
|
|
||||||
r = jtheta(n, z, q, 3)
|
|
||||||
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=3)
|
|
||||||
assert r.ae(r1)
|
|
||||||
|
|
||||||
# identity in Wittaker, Watson &21.41
|
|
||||||
q = one/3
|
|
||||||
z = zero
|
|
||||||
a = [0]*5
|
|
||||||
a[1] = jtheta(1, z, q, 3)/jtheta(1, z, q, 1)
|
|
||||||
for n in [2,3,4]:
|
|
||||||
a[n] = jtheta(n, z, q, 2)/jtheta(n, z, q)
|
|
||||||
equality = a[2] + a[3] + a[4] - a[1]
|
|
||||||
assert(equality.ae(0))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_jsn():
|
|
||||||
"""
|
|
||||||
Test some special cases of the sn(z, q) function.
|
|
||||||
"""
|
|
||||||
mp.dps = 100
|
|
||||||
|
|
||||||
# trival case
|
|
||||||
result = jsn(zero, zero)
|
|
||||||
assert(result == zero)
|
|
||||||
|
|
||||||
# Abramowitz Table 16.5
|
|
||||||
#
|
|
||||||
# sn(0, m) = 0
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
qstring = str(random.random())
|
|
||||||
q = mpf(qstring)
|
|
||||||
|
|
||||||
equality = jsn(zero, q)
|
|
||||||
assert(equality.ae(0))
|
|
||||||
|
|
||||||
# Abramowitz Table 16.6.1
|
|
||||||
#
|
|
||||||
# sn(z, 0) = sin(z), m == 0
|
|
||||||
#
|
|
||||||
# sn(z, 1) = tanh(z), m == 1
|
|
||||||
#
|
|
||||||
# It would be nice to test these, but I find that they run
|
|
||||||
# in to numerical trouble. I'm currently treating as a boundary
|
|
||||||
# case for sn function.
|
|
||||||
|
|
||||||
mp.dps = 25
|
|
||||||
arg = one/10
|
|
||||||
#N[JacobiSN[1/10, 2^-100], 25]
|
|
||||||
res = mpf('0.09983341664682815230681420')
|
|
||||||
m = ldexp(one, -100)
|
|
||||||
result = jsn(arg, m)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# N[JacobiSN[1/10, 1/10], 25]
|
|
||||||
res = mpf('0.09981686718599080096451168')
|
|
||||||
result = jsn(arg, arg)
|
|
||||||
assert(result.ae(res))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_jcn():
|
|
||||||
"""
|
|
||||||
Test some special cases of the cn(z, q) function.
|
|
||||||
"""
|
|
||||||
mp.dps = 100
|
|
||||||
|
|
||||||
# Abramowitz Table 16.5
|
|
||||||
# cn(0, q) = 1
|
|
||||||
qstring = str(random.random())
|
|
||||||
q = mpf(qstring)
|
|
||||||
cn = jcn(zero, q)
|
|
||||||
assert(cn.ae(one))
|
|
||||||
|
|
||||||
# Abramowitz Table 16.6.2
|
|
||||||
#
|
|
||||||
# cn(u, 0) = cos(u), m == 0
|
|
||||||
#
|
|
||||||
# cn(u, 1) = sech(z), m == 1
|
|
||||||
#
|
|
||||||
# It would be nice to test these, but I find that they run
|
|
||||||
# in to numerical trouble. I'm currently treating as a boundary
|
|
||||||
# case for cn function.
|
|
||||||
|
|
||||||
mp.dps = 25
|
|
||||||
arg = one/10
|
|
||||||
m = ldexp(one, -100)
|
|
||||||
#N[JacobiCN[1/10, 2^-100], 25]
|
|
||||||
res = mpf('0.9950041652780257660955620')
|
|
||||||
result = jcn(arg, m)
|
|
||||||
assert(result.ae(res))
|
|
||||||
|
|
||||||
# N[JacobiCN[1/10, 1/10], 25]
|
|
||||||
res = mpf('0.9950058256237368748520459')
|
|
||||||
result = jcn(arg, arg)
|
|
||||||
assert(result.ae(res))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_jdn():
|
|
||||||
"""
|
|
||||||
Test some special cases of the dn(z, q) function.
|
|
||||||
"""
|
|
||||||
mp.dps = 100
|
|
||||||
|
|
||||||
# Abramowitz Table 16.5
|
|
||||||
# dn(0, q) = 1
|
|
||||||
mstring = str(random.random())
|
|
||||||
m = mpf(mstring)
|
|
||||||
|
|
||||||
dn = jdn(zero, m)
|
|
||||||
assert(dn.ae(one))
|
|
||||||
|
|
||||||
mp.dps = 25
|
|
||||||
# N[JacobiDN[1/10, 1/10], 25]
|
|
||||||
res = mpf('0.9995017055025556219713297')
|
|
||||||
arg = one/10
|
|
||||||
result = jdn(arg, arg)
|
|
||||||
assert(result.ae(res))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
|
|
||||||
def test_sn_cn_dn_identities():
|
|
||||||
"""
|
|
||||||
Tests the some of the jacobi elliptic function identities found
|
|
||||||
on Mathworld. Haven't found in Abramowitz.
|
|
||||||
"""
|
|
||||||
mp.dps = 100
|
|
||||||
N = 5
|
|
||||||
for i in range(N):
|
|
||||||
qstring = str(random.random())
|
|
||||||
q = mpf(qstring)
|
|
||||||
zstring = str(100*random.random())
|
|
||||||
z = mpf(zstring)
|
|
||||||
|
|
||||||
# MathWorld
|
|
||||||
# sn(z, q)**2 + cn(z, q)**2 == 1
|
|
||||||
term1 = jsn(z, q)**2
|
|
||||||
term2 = jcn(z, q)**2
|
|
||||||
equality = one - term1 - term2
|
|
||||||
assert(equality.ae(0))
|
|
||||||
|
|
||||||
# MathWorld
|
|
||||||
# k**2 * sn(z, m)**2 + dn(z, m)**2 == 1
|
|
||||||
for i in range(N):
|
|
||||||
mstring = str(random.random())
|
|
||||||
m = mpf(qstring)
|
|
||||||
k = m.sqrt()
|
|
||||||
zstring = str(10*random.random())
|
|
||||||
z = mpf(zstring)
|
|
||||||
term1 = k**2 * jsn(z, m)**2
|
|
||||||
term2 = jdn(z, m)**2
|
|
||||||
equality = one - term1 - term2
|
|
||||||
assert(equality.ae(0))
|
|
||||||
|
|
||||||
|
|
||||||
for i in range(N):
|
|
||||||
mstring = str(random.random())
|
|
||||||
m = mpf(mstring)
|
|
||||||
k = m.sqrt()
|
|
||||||
zstring = str(random.random())
|
|
||||||
z = mpf(zstring)
|
|
||||||
|
|
||||||
# MathWorld
|
|
||||||
# k**2 * cn(z, m)**2 + (1 - k**2) = dn(z, m)**2
|
|
||||||
term1 = k**2 * jcn(z, m)**2
|
|
||||||
term2 = 1 - k**2
|
|
||||||
term3 = jdn(z, m)**2
|
|
||||||
equality = term3 - term1 - term2
|
|
||||||
assert(equality.ae(0))
|
|
||||||
|
|
||||||
K = ellipk(k**2)
|
|
||||||
# Abramowitz Table 16.5
|
|
||||||
# sn(K, m) = 1; K is K(k), first complete elliptic integral
|
|
||||||
r = jsn(K, m)
|
|
||||||
assert(r.ae(one))
|
|
||||||
|
|
||||||
# Abramowitz Table 16.5
|
|
||||||
# cn(K, q) = 0; K is K(k), first complete elliptic integral
|
|
||||||
equality = jcn(K, m)
|
|
||||||
assert(equality.ae(0))
|
|
||||||
|
|
||||||
# Abramowitz Table 16.6.3
|
|
||||||
# dn(z, 0) = 1, m == 0
|
|
||||||
z = m
|
|
||||||
value = jdn(z, zero)
|
|
||||||
assert(value.ae(one))
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_sn_cn_dn_complex():
|
|
||||||
mp.dps = 30
|
|
||||||
# N[JacobiSN[1/4 + I/8, 1/3 + I/7], 35] in Mathematica
|
|
||||||
res = mpf('0.2495674401066275492326652143537') + \
|
|
||||||
mpf('0.12017344422863833381301051702823') * j
|
|
||||||
u = mpf(1)/4 + j/8
|
|
||||||
m = mpf(1)/3 + j/7
|
|
||||||
r = jsn(u, m)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
#N[JacobiCN[1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('0.9762691700944007312693721148331') - \
|
|
||||||
mpf('0.0307203994181623243583169154824')*j
|
|
||||||
r = jcn(u, m)
|
|
||||||
#assert r.real.ae(res.real)
|
|
||||||
#assert r.imag.ae(res.imag)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
|
|
||||||
#N[JacobiDN[1/4 + I/8, 1/3 + I/7], 35]
|
|
||||||
res = mpf('0.99639490163039577560547478589753039') - \
|
|
||||||
mpf('0.01346296520008176393432491077244994')*j
|
|
||||||
r = jdn(u, m)
|
|
||||||
assert(mpc_ae(r, res))
|
|
||||||
mp.dps = 15
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,882 +0,0 @@
|
||||||
from mpmath.libmp import *
|
|
||||||
from mpmath import *
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import math
|
|
||||||
import cmath
|
|
||||||
|
|
||||||
def mpc_ae(a, b, eps=eps):
|
|
||||||
res = True
|
|
||||||
res = res and a.real.ae(b.real, eps)
|
|
||||||
res = res and a.imag.ae(b.imag, eps)
|
|
||||||
return res
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------
|
|
||||||
# Constants and functions
|
|
||||||
#
|
|
||||||
|
|
||||||
tpi = "3.1415926535897932384626433832795028841971693993751058209749445923078\
|
|
||||||
1640628620899862803482534211706798"
|
|
||||||
te = "2.71828182845904523536028747135266249775724709369995957496696762772407\
|
|
||||||
663035354759457138217852516642743"
|
|
||||||
tdegree = "0.017453292519943295769236907684886127134428718885417254560971914\
|
|
||||||
4017100911460344944368224156963450948221"
|
|
||||||
teuler = "0.5772156649015328606065120900824024310421593359399235988057672348\
|
|
||||||
84867726777664670936947063291746749516"
|
|
||||||
tln2 = "0.693147180559945309417232121458176568075500134360255254120680009493\
|
|
||||||
393621969694715605863326996418687542"
|
|
||||||
tln10 = "2.30258509299404568401799145468436420760110148862877297603332790096\
|
|
||||||
757260967735248023599720508959829834"
|
|
||||||
tcatalan = "0.91596559417721901505460351493238411077414937428167213426649811\
|
|
||||||
9621763019776254769479356512926115106249"
|
|
||||||
tkhinchin = "2.6854520010653064453097148354817956938203822939944629530511523\
|
|
||||||
4555721885953715200280114117493184769800"
|
|
||||||
tglaisher = "1.2824271291006226368753425688697917277676889273250011920637400\
|
|
||||||
2174040630885882646112973649195820237439420646"
|
|
||||||
tapery = "1.2020569031595942853997381615114499907649862923404988817922715553\
|
|
||||||
4183820578631309018645587360933525815"
|
|
||||||
tphi = "1.618033988749894848204586834365638117720309179805762862135448622705\
|
|
||||||
26046281890244970720720418939113748475"
|
|
||||||
tmertens = "0.26149721284764278375542683860869585905156664826119920619206421\
|
|
||||||
3924924510897368209714142631434246651052"
|
|
||||||
ttwinprime = "0.660161815846869573927812110014555778432623360284733413319448\
|
|
||||||
423335405642304495277143760031413839867912"
|
|
||||||
|
|
||||||
def test_constants():
|
|
||||||
for prec in [3, 7, 10, 15, 20, 37, 80, 100, 29]:
|
|
||||||
mp.dps = prec
|
|
||||||
assert pi == mpf(tpi)
|
|
||||||
assert e == mpf(te)
|
|
||||||
assert degree == mpf(tdegree)
|
|
||||||
assert euler == mpf(teuler)
|
|
||||||
assert ln2 == mpf(tln2)
|
|
||||||
assert ln10 == mpf(tln10)
|
|
||||||
assert catalan == mpf(tcatalan)
|
|
||||||
assert khinchin == mpf(tkhinchin)
|
|
||||||
assert glaisher == mpf(tglaisher)
|
|
||||||
assert phi == mpf(tphi)
|
|
||||||
if prec < 50:
|
|
||||||
assert mertens == mpf(tmertens)
|
|
||||||
assert twinprime == mpf(ttwinprime)
|
|
||||||
mp.dps = 15
|
|
||||||
assert pi >= -1
|
|
||||||
assert pi > 2
|
|
||||||
assert pi > 3
|
|
||||||
assert pi < 4
|
|
||||||
|
|
||||||
def test_exact_sqrts():
|
|
||||||
for i in range(20000):
|
|
||||||
assert sqrt(mpf(i*i)) == i
|
|
||||||
random.seed(1)
|
|
||||||
for prec in [100, 300, 1000, 10000]:
|
|
||||||
mp.dps = prec
|
|
||||||
for i in range(20):
|
|
||||||
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
|
|
||||||
assert sqrt(mpf(A*A)) == A
|
|
||||||
mp.dps = 15
|
|
||||||
for i in range(100):
|
|
||||||
for a in [1, 8, 25, 112307]:
|
|
||||||
assert sqrt(mpf((a*a, 2*i))) == mpf((a, i))
|
|
||||||
assert sqrt(mpf((a*a, -2*i))) == mpf((a, -i))
|
|
||||||
|
|
||||||
def test_sqrt_rounding():
|
|
||||||
for i in [2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15]:
|
|
||||||
i = from_int(i)
|
|
||||||
for dps in [7, 15, 83, 106, 2000]:
|
|
||||||
mp.dps = dps
|
|
||||||
a = mpf_pow_int(mpf_sqrt(i, mp.prec, round_down), 2, mp.prec, round_down)
|
|
||||||
b = mpf_pow_int(mpf_sqrt(i, mp.prec, round_up), 2, mp.prec, round_up)
|
|
||||||
assert mpf_lt(a, i)
|
|
||||||
assert mpf_gt(b, i)
|
|
||||||
random.seed(1234)
|
|
||||||
prec = 100
|
|
||||||
for rnd in [round_down, round_nearest, round_ceiling]:
|
|
||||||
for i in range(100):
|
|
||||||
a = mpf_rand(prec)
|
|
||||||
b = mpf_mul(a, a)
|
|
||||||
assert mpf_sqrt(b, prec, rnd) == a
|
|
||||||
# Test some extreme cases
|
|
||||||
mp.dps = 100
|
|
||||||
a = mpf(9) + 1e-90
|
|
||||||
b = mpf(9) - 1e-90
|
|
||||||
mp.dps = 15
|
|
||||||
assert sqrt(a, rounding='d') == 3
|
|
||||||
assert sqrt(a, rounding='n') == 3
|
|
||||||
assert sqrt(a, rounding='u') > 3
|
|
||||||
assert sqrt(b, rounding='d') < 3
|
|
||||||
assert sqrt(b, rounding='n') == 3
|
|
||||||
assert sqrt(b, rounding='u') == 3
|
|
||||||
# A worst case, from the MPFR test suite
|
|
||||||
assert sqrt(mpf('7.0503726185518891')) == mpf('2.655253776675949')
|
|
||||||
|
|
||||||
def test_float_sqrt():
|
|
||||||
mp.dps = 15
|
|
||||||
# These should round identically
|
|
||||||
for x in [0, 1e-7, 0.1, 0.5, 1, 2, 3, 4, 5, 0.333, 76.19]:
|
|
||||||
assert sqrt(mpf(x)) == float(x)**0.5
|
|
||||||
assert sqrt(-1) == 1j
|
|
||||||
assert sqrt(-2).ae(cmath.sqrt(-2))
|
|
||||||
assert sqrt(-3).ae(cmath.sqrt(-3))
|
|
||||||
assert sqrt(-100).ae(cmath.sqrt(-100))
|
|
||||||
assert sqrt(1j).ae(cmath.sqrt(1j))
|
|
||||||
assert sqrt(-1j).ae(cmath.sqrt(-1j))
|
|
||||||
assert sqrt(math.pi + math.e*1j).ae(cmath.sqrt(math.pi + math.e*1j))
|
|
||||||
assert sqrt(math.pi - math.e*1j).ae(cmath.sqrt(math.pi - math.e*1j))
|
|
||||||
|
|
||||||
def test_hypot():
|
|
||||||
assert hypot(0, 0) == 0
|
|
||||||
assert hypot(0, 0.33) == mpf(0.33)
|
|
||||||
assert hypot(0.33, 0) == mpf(0.33)
|
|
||||||
assert hypot(-0.33, 0) == mpf(0.33)
|
|
||||||
assert hypot(3, 4) == mpf(5)
|
|
||||||
|
|
||||||
def test_exact_cbrt():
|
|
||||||
for i in range(0, 20000, 200):
|
|
||||||
assert cbrt(mpf(i*i*i)) == i
|
|
||||||
random.seed(1)
|
|
||||||
for prec in [100, 300, 1000, 10000]:
|
|
||||||
mp.dps = prec
|
|
||||||
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
|
|
||||||
assert cbrt(mpf(A*A*A)) == A
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_exp():
|
|
||||||
assert exp(0) == 1
|
|
||||||
assert exp(10000).ae(mpf('8.8068182256629215873e4342'))
|
|
||||||
assert exp(-10000).ae(mpf('1.1354838653147360985e-4343'))
|
|
||||||
a = exp(mpf((1, 8198646019315405L, -53, 53)))
|
|
||||||
assert(a.bc == bitcount(a.man))
|
|
||||||
mp.prec = 67
|
|
||||||
a = exp(mpf((1, 1781864658064754565L, -60, 61)))
|
|
||||||
assert(a.bc == bitcount(a.man))
|
|
||||||
mp.prec = 53
|
|
||||||
assert exp(ln2 * 10).ae(1024)
|
|
||||||
assert exp(2+2j).ae(cmath.exp(2+2j))
|
|
||||||
|
|
||||||
def test_issue_33():
|
|
||||||
mp.dps = 512
|
|
||||||
a = exp(-1)
|
|
||||||
b = exp(1)
|
|
||||||
mp.dps = 15
|
|
||||||
assert (+a).ae(0.36787944117144233)
|
|
||||||
assert (+b).ae(2.7182818284590451)
|
|
||||||
|
|
||||||
def test_log():
|
|
||||||
mp.dps = 15
|
|
||||||
assert log(1) == 0
|
|
||||||
for x in [0.5, 1.5, 2.0, 3.0, 100, 10**50, 1e-50]:
|
|
||||||
assert log(x).ae(math.log(x))
|
|
||||||
assert log(x, x) == 1
|
|
||||||
assert log(1024, 2) == 10
|
|
||||||
assert log(10**1234, 10) == 1234
|
|
||||||
assert log(2+2j).ae(cmath.log(2+2j))
|
|
||||||
# Accuracy near 1
|
|
||||||
assert (log(0.6+0.8j).real*10**17).ae(2.2204460492503131)
|
|
||||||
assert (log(0.6-0.8j).real*10**17).ae(2.2204460492503131)
|
|
||||||
assert (log(0.8-0.6j).real*10**17).ae(2.2204460492503131)
|
|
||||||
assert (log(1+1e-8j).real*10**16).ae(0.5)
|
|
||||||
assert (log(1-1e-8j).real*10**16).ae(0.5)
|
|
||||||
assert (log(-1+1e-8j).real*10**16).ae(0.5)
|
|
||||||
assert (log(-1-1e-8j).real*10**16).ae(0.5)
|
|
||||||
assert (log(1j+1e-8).real*10**16).ae(0.5)
|
|
||||||
assert (log(1j-1e-8).real*10**16).ae(0.5)
|
|
||||||
assert (log(-1j+1e-8).real*10**16).ae(0.5)
|
|
||||||
assert (log(-1j-1e-8).real*10**16).ae(0.5)
|
|
||||||
assert (log(1+1e-40j).real*10**80).ae(0.5)
|
|
||||||
assert (log(1j+1e-40).real*10**80).ae(0.5)
|
|
||||||
# Huge
|
|
||||||
assert log(ldexp(1.234,10**20)).ae(log(2)*1e20)
|
|
||||||
assert log(ldexp(1.234,10**200)).ae(log(2)*1e200)
|
|
||||||
# Some special values
|
|
||||||
assert log(mpc(0,0)) == mpc(-inf,0)
|
|
||||||
assert isnan(log(mpc(nan,0)).real)
|
|
||||||
assert isnan(log(mpc(nan,0)).imag)
|
|
||||||
assert isnan(log(mpc(0,nan)).real)
|
|
||||||
assert isnan(log(mpc(0,nan)).imag)
|
|
||||||
assert isnan(log(mpc(nan,1)).real)
|
|
||||||
assert isnan(log(mpc(nan,1)).imag)
|
|
||||||
assert isnan(log(mpc(1,nan)).real)
|
|
||||||
assert isnan(log(mpc(1,nan)).imag)
|
|
||||||
|
|
||||||
def test_trig_hyperb_basic():
|
|
||||||
for x in (range(100) + range(-100,0)):
|
|
||||||
t = x / 4.1
|
|
||||||
assert cos(mpf(t)).ae(math.cos(t))
|
|
||||||
assert sin(mpf(t)).ae(math.sin(t))
|
|
||||||
assert tan(mpf(t)).ae(math.tan(t))
|
|
||||||
assert cosh(mpf(t)).ae(math.cosh(t))
|
|
||||||
assert sinh(mpf(t)).ae(math.sinh(t))
|
|
||||||
assert tanh(mpf(t)).ae(math.tanh(t))
|
|
||||||
assert sin(1+1j).ae(cmath.sin(1+1j))
|
|
||||||
assert sin(-4-3.6j).ae(cmath.sin(-4-3.6j))
|
|
||||||
assert cos(1+1j).ae(cmath.cos(1+1j))
|
|
||||||
assert cos(-4-3.6j).ae(cmath.cos(-4-3.6j))
|
|
||||||
|
|
||||||
def test_degrees():
|
|
||||||
assert cos(0*degree) == 1
|
|
||||||
assert cos(90*degree).ae(0)
|
|
||||||
assert cos(180*degree).ae(-1)
|
|
||||||
assert cos(270*degree).ae(0)
|
|
||||||
assert cos(360*degree).ae(1)
|
|
||||||
assert sin(0*degree) == 0
|
|
||||||
assert sin(90*degree).ae(1)
|
|
||||||
assert sin(180*degree).ae(0)
|
|
||||||
assert sin(270*degree).ae(-1)
|
|
||||||
assert sin(360*degree).ae(0)
|
|
||||||
|
|
||||||
def random_complexes(N):
|
|
||||||
random.seed(1)
|
|
||||||
a = []
|
|
||||||
for i in range(N):
|
|
||||||
x1 = random.uniform(-10, 10)
|
|
||||||
y1 = random.uniform(-10, 10)
|
|
||||||
x2 = random.uniform(-10, 10)
|
|
||||||
y2 = random.uniform(-10, 10)
|
|
||||||
z1 = complex(x1, y1)
|
|
||||||
z2 = complex(x2, y2)
|
|
||||||
a.append((z1, z2))
|
|
||||||
return a
|
|
||||||
|
|
||||||
def test_complex_powers():
|
|
||||||
for dps in [15, 30, 100]:
|
|
||||||
# Check accuracy for complex square root
|
|
||||||
mp.dps = dps
|
|
||||||
a = mpc(1j)**0.5
|
|
||||||
assert a.real == a.imag == mpf(2)**0.5 / 2
|
|
||||||
mp.dps = 15
|
|
||||||
random.seed(1)
|
|
||||||
for (z1, z2) in random_complexes(100):
|
|
||||||
assert (mpc(z1)**mpc(z2)).ae(z1**z2, 1e-12)
|
|
||||||
assert (e**(-pi*1j)).ae(-1)
|
|
||||||
mp.dps = 50
|
|
||||||
assert (e**(-pi*1j)).ae(-1)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_complex_sqrt_accuracy():
|
|
||||||
def test_mpc_sqrt(lst):
|
|
||||||
for a, b in lst:
|
|
||||||
z = mpc(a + j*b)
|
|
||||||
assert mpc_ae(sqrt(z*z), z)
|
|
||||||
z = mpc(-a + j*b)
|
|
||||||
assert mpc_ae(sqrt(z*z), -z)
|
|
||||||
z = mpc(a - j*b)
|
|
||||||
assert mpc_ae(sqrt(z*z), z)
|
|
||||||
z = mpc(-a - j*b)
|
|
||||||
assert mpc_ae(sqrt(z*z), -z)
|
|
||||||
random.seed(2)
|
|
||||||
N = 10
|
|
||||||
mp.dps = 30
|
|
||||||
dps = mp.dps
|
|
||||||
test_mpc_sqrt([(random.uniform(0, 10),random.uniform(0, 10)) for i in range(N)])
|
|
||||||
test_mpc_sqrt([(i + 0.1, (i + 0.2)*10**i) for i in range(N)])
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_atan():
|
|
||||||
mp.dps = 15
|
|
||||||
assert atan(-2.3).ae(math.atan(-2.3))
|
|
||||||
assert atan(1e-50) == 1e-50
|
|
||||||
assert atan(1e50).ae(pi/2)
|
|
||||||
assert atan(-1e-50) == -1e-50
|
|
||||||
assert atan(-1e50).ae(-pi/2)
|
|
||||||
assert atan(10**1000).ae(pi/2)
|
|
||||||
for dps in [25, 70, 100, 300, 1000]:
|
|
||||||
mp.dps = dps
|
|
||||||
assert (4*atan(1)).ae(pi)
|
|
||||||
mp.dps = 15
|
|
||||||
pi2 = pi/2
|
|
||||||
assert atan(mpc(inf,-1)).ae(pi2)
|
|
||||||
assert atan(mpc(inf,0)).ae(pi2)
|
|
||||||
assert atan(mpc(inf,1)).ae(pi2)
|
|
||||||
assert atan(mpc(1,inf)).ae(pi2)
|
|
||||||
assert atan(mpc(0,inf)).ae(pi2)
|
|
||||||
assert atan(mpc(-1,inf)).ae(-pi2)
|
|
||||||
assert atan(mpc(-inf,1)).ae(-pi2)
|
|
||||||
assert atan(mpc(-inf,0)).ae(-pi2)
|
|
||||||
assert atan(mpc(-inf,-1)).ae(-pi2)
|
|
||||||
assert atan(mpc(-1,-inf)).ae(-pi2)
|
|
||||||
assert atan(mpc(0,-inf)).ae(-pi2)
|
|
||||||
assert atan(mpc(1,-inf)).ae(pi2)
|
|
||||||
|
|
||||||
def test_atan2():
|
|
||||||
mp.dps = 15
|
|
||||||
assert atan2(1,1).ae(pi/4)
|
|
||||||
assert atan2(1,-1).ae(3*pi/4)
|
|
||||||
assert atan2(-1,-1).ae(-3*pi/4)
|
|
||||||
assert atan2(-1,1).ae(-pi/4)
|
|
||||||
assert atan2(-1,0).ae(-pi/2)
|
|
||||||
assert atan2(1,0).ae(pi/2)
|
|
||||||
assert atan2(0,0) == 0
|
|
||||||
assert atan2(inf,0).ae(pi/2)
|
|
||||||
assert atan2(-inf,0).ae(-pi/2)
|
|
||||||
assert isnan(atan2(inf,inf))
|
|
||||||
assert isnan(atan2(-inf,inf))
|
|
||||||
assert isnan(atan2(inf,-inf))
|
|
||||||
assert isnan(atan2(3,nan))
|
|
||||||
assert isnan(atan2(nan,3))
|
|
||||||
assert isnan(atan2(0,nan))
|
|
||||||
assert isnan(atan2(nan,0))
|
|
||||||
assert atan2(0,inf) == 0
|
|
||||||
assert atan2(0,-inf).ae(pi)
|
|
||||||
assert atan2(10,inf) == 0
|
|
||||||
assert atan2(-10,inf) == 0
|
|
||||||
assert atan2(-10,-inf).ae(-pi)
|
|
||||||
assert atan2(10,-inf).ae(pi)
|
|
||||||
assert atan2(inf,10).ae(pi/2)
|
|
||||||
assert atan2(inf,-10).ae(pi/2)
|
|
||||||
assert atan2(-inf,10).ae(-pi/2)
|
|
||||||
assert atan2(-inf,-10).ae(-pi/2)
|
|
||||||
|
|
||||||
def test_areal_inverses():
|
|
||||||
assert asin(mpf(0)) == 0
|
|
||||||
assert asinh(mpf(0)) == 0
|
|
||||||
assert acosh(mpf(1)) == 0
|
|
||||||
assert isinstance(asin(mpf(0.5)), mpf)
|
|
||||||
assert isinstance(asin(mpf(2.0)), mpc)
|
|
||||||
assert isinstance(acos(mpf(0.5)), mpf)
|
|
||||||
assert isinstance(acos(mpf(2.0)), mpc)
|
|
||||||
assert isinstance(atanh(mpf(0.1)), mpf)
|
|
||||||
assert isinstance(atanh(mpf(1.1)), mpc)
|
|
||||||
|
|
||||||
random.seed(1)
|
|
||||||
for i in range(50):
|
|
||||||
x = random.uniform(0, 1)
|
|
||||||
assert asin(mpf(x)).ae(math.asin(x))
|
|
||||||
assert acos(mpf(x)).ae(math.acos(x))
|
|
||||||
|
|
||||||
x = random.uniform(-10, 10)
|
|
||||||
assert asinh(mpf(x)).ae(cmath.asinh(x).real)
|
|
||||||
assert isinstance(asinh(mpf(x)), mpf)
|
|
||||||
x = random.uniform(1, 10)
|
|
||||||
assert acosh(mpf(x)).ae(cmath.acosh(x).real)
|
|
||||||
assert isinstance(acosh(mpf(x)), mpf)
|
|
||||||
x = random.uniform(-10, 0.999)
|
|
||||||
assert isinstance(acosh(mpf(x)), mpc)
|
|
||||||
|
|
||||||
x = random.uniform(-1, 1)
|
|
||||||
assert atanh(mpf(x)).ae(cmath.atanh(x).real)
|
|
||||||
assert isinstance(atanh(mpf(x)), mpf)
|
|
||||||
|
|
||||||
dps = mp.dps
|
|
||||||
mp.dps = 300
|
|
||||||
assert isinstance(asin(0.5), mpf)
|
|
||||||
mp.dps = 1000
|
|
||||||
assert asin(1).ae(pi/2)
|
|
||||||
assert asin(-1).ae(-pi/2)
|
|
||||||
mp.dps = dps
|
|
||||||
|
|
||||||
def test_invhyperb_inaccuracy():
|
|
||||||
mp.dps = 15
|
|
||||||
assert (asinh(1e-5)*10**5).ae(0.99999999998333333)
|
|
||||||
assert (asinh(1e-10)*10**10).ae(1)
|
|
||||||
assert (asinh(1e-50)*10**50).ae(1)
|
|
||||||
assert (asinh(-1e-5)*10**5).ae(-0.99999999998333333)
|
|
||||||
assert (asinh(-1e-10)*10**10).ae(-1)
|
|
||||||
assert (asinh(-1e-50)*10**50).ae(-1)
|
|
||||||
assert asinh(10**20).ae(46.744849040440862)
|
|
||||||
assert asinh(-10**20).ae(-46.744849040440862)
|
|
||||||
assert (tanh(1e-10)*10**10).ae(1)
|
|
||||||
assert (tanh(-1e-10)*10**10).ae(-1)
|
|
||||||
assert (atanh(1e-10)*10**10).ae(1)
|
|
||||||
assert (atanh(-1e-10)*10**10).ae(-1)
|
|
||||||
|
|
||||||
def test_complex_functions():
|
|
||||||
for x in (range(10) + range(-10,0)):
|
|
||||||
for y in (range(10) + range(-10,0)):
|
|
||||||
z = complex(x, y)/4.3 + 0.01j
|
|
||||||
assert exp(mpc(z)).ae(cmath.exp(z))
|
|
||||||
assert log(mpc(z)).ae(cmath.log(z))
|
|
||||||
assert cos(mpc(z)).ae(cmath.cos(z))
|
|
||||||
assert sin(mpc(z)).ae(cmath.sin(z))
|
|
||||||
assert tan(mpc(z)).ae(cmath.tan(z))
|
|
||||||
assert sinh(mpc(z)).ae(cmath.sinh(z))
|
|
||||||
assert cosh(mpc(z)).ae(cmath.cosh(z))
|
|
||||||
assert tanh(mpc(z)).ae(cmath.tanh(z))
|
|
||||||
|
|
||||||
def test_complex_inverse_functions():
|
|
||||||
for (z1, z2) in random_complexes(30):
|
|
||||||
# apparently cmath uses a different branch, so we
|
|
||||||
# can't use it for comparison
|
|
||||||
assert sinh(asinh(z1)).ae(z1)
|
|
||||||
#
|
|
||||||
assert acosh(z1).ae(cmath.acosh(z1))
|
|
||||||
assert atanh(z1).ae(cmath.atanh(z1))
|
|
||||||
assert atan(z1).ae(cmath.atan(z1))
|
|
||||||
# the reason we set a big eps here is that the cmath
|
|
||||||
# functions are inaccurate
|
|
||||||
assert asin(z1).ae(cmath.asin(z1), rel_eps=1e-12)
|
|
||||||
assert acos(z1).ae(cmath.acos(z1), rel_eps=1e-12)
|
|
||||||
one = mpf(1)
|
|
||||||
for i in range(-9, 10, 3):
|
|
||||||
for k in range(-9, 10, 3):
|
|
||||||
a = 0.9*j*10**k + 0.8*one*10**i
|
|
||||||
b = cos(acos(a))
|
|
||||||
assert b.ae(a)
|
|
||||||
b = sin(asin(a))
|
|
||||||
assert b.ae(a)
|
|
||||||
one = mpf(1)
|
|
||||||
err = 2*10**-15
|
|
||||||
for i in range(-9, 9, 3):
|
|
||||||
for k in range(-9, 9, 3):
|
|
||||||
a = -0.9*10**k + j*0.8*one*10**i
|
|
||||||
b = cosh(acosh(a))
|
|
||||||
assert b.ae(a, err)
|
|
||||||
b = sinh(asinh(a))
|
|
||||||
assert b.ae(a, err)
|
|
||||||
|
|
||||||
def test_reciprocal_functions():
|
|
||||||
assert sec(3).ae(-1.01010866590799375)
|
|
||||||
assert csc(3).ae(7.08616739573718592)
|
|
||||||
assert cot(3).ae(-7.01525255143453347)
|
|
||||||
assert sech(3).ae(0.0993279274194332078)
|
|
||||||
assert csch(3).ae(0.0998215696688227329)
|
|
||||||
assert coth(3).ae(1.00496982331368917)
|
|
||||||
assert asec(3).ae(1.23095941734077468)
|
|
||||||
assert acsc(3).ae(0.339836909454121937)
|
|
||||||
assert acot(3).ae(0.321750554396642193)
|
|
||||||
assert asech(0.5).ae(1.31695789692481671)
|
|
||||||
assert acsch(3).ae(0.327450150237258443)
|
|
||||||
assert acoth(3).ae(0.346573590279972655)
|
|
||||||
|
|
||||||
def test_ldexp():
|
|
||||||
mp.dps = 15
|
|
||||||
assert ldexp(mpf(2.5), 0) == 2.5
|
|
||||||
assert ldexp(mpf(2.5), -1) == 1.25
|
|
||||||
assert ldexp(mpf(2.5), 2) == 10
|
|
||||||
assert ldexp(mpf('inf'), 3) == mpf('inf')
|
|
||||||
|
|
||||||
def test_frexp():
|
|
||||||
mp.dps = 15
|
|
||||||
assert frexp(0) == (0.0, 0)
|
|
||||||
assert frexp(9) == (0.5625, 4)
|
|
||||||
assert frexp(1) == (0.5, 1)
|
|
||||||
assert frexp(0.2) == (0.8, -2)
|
|
||||||
assert frexp(1000) == (0.9765625, 10)
|
|
||||||
|
|
||||||
def test_aliases():
|
|
||||||
assert ln(7) == log(7)
|
|
||||||
assert log10(3.75) == log(3.75,10)
|
|
||||||
assert degrees(5.6) == 5.6 / degree
|
|
||||||
assert radians(5.6) == 5.6 * degree
|
|
||||||
assert power(-1,0.5) == j
|
|
||||||
assert modf(25,7) == 4.0 and isinstance(modf(25,7), mpf)
|
|
||||||
|
|
||||||
def test_arg_sign():
|
|
||||||
assert arg(3) == 0
|
|
||||||
assert arg(-3).ae(pi)
|
|
||||||
assert arg(j).ae(pi/2)
|
|
||||||
assert arg(-j).ae(-pi/2)
|
|
||||||
assert arg(0) == 0
|
|
||||||
assert isnan(atan2(3,nan))
|
|
||||||
assert isnan(atan2(nan,3))
|
|
||||||
assert isnan(atan2(0,nan))
|
|
||||||
assert isnan(atan2(nan,0))
|
|
||||||
assert isnan(atan2(nan,nan))
|
|
||||||
assert arg(inf) == 0
|
|
||||||
assert arg(-inf).ae(pi)
|
|
||||||
assert isnan(arg(nan))
|
|
||||||
#assert arg(inf*j).ae(pi/2)
|
|
||||||
assert sign(0) == 0
|
|
||||||
assert sign(3) == 1
|
|
||||||
assert sign(-3) == -1
|
|
||||||
assert sign(inf) == 1
|
|
||||||
assert sign(-inf) == -1
|
|
||||||
assert isnan(sign(nan))
|
|
||||||
assert sign(j) == j
|
|
||||||
assert sign(-3*j) == -j
|
|
||||||
assert sign(1+j).ae((1+j)/sqrt(2))
|
|
||||||
|
|
||||||
def test_misc_bugs():
|
|
||||||
# test that this doesn't raise an exception
|
|
||||||
mp.dps = 1000
|
|
||||||
log(1302)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_arange():
|
|
||||||
assert arange(10) == [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0'),
|
|
||||||
mpf('4.0'), mpf('5.0'), mpf('6.0'), mpf('7.0'),
|
|
||||||
mpf('8.0'), mpf('9.0')]
|
|
||||||
assert arange(-5, 5) == [mpf('-5.0'), mpf('-4.0'), mpf('-3.0'),
|
|
||||||
mpf('-2.0'), mpf('-1.0'), mpf('0.0'),
|
|
||||||
mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
|
|
||||||
assert arange(0, 1, 0.1) == [mpf('0.0'), mpf('0.10000000000000001'),
|
|
||||||
mpf('0.20000000000000001'),
|
|
||||||
mpf('0.30000000000000004'),
|
|
||||||
mpf('0.40000000000000002'),
|
|
||||||
mpf('0.5'), mpf('0.60000000000000009'),
|
|
||||||
mpf('0.70000000000000007'),
|
|
||||||
mpf('0.80000000000000004'),
|
|
||||||
mpf('0.90000000000000002')]
|
|
||||||
assert arange(17, -9, -3) == [mpf('17.0'), mpf('14.0'), mpf('11.0'),
|
|
||||||
mpf('8.0'), mpf('5.0'), mpf('2.0'),
|
|
||||||
mpf('-1.0'), mpf('-4.0'), mpf('-7.0')]
|
|
||||||
assert arange(0.2, 0.1, -0.1) == [mpf('0.20000000000000001')]
|
|
||||||
assert arange(0) == []
|
|
||||||
assert arange(1000, -1) == []
|
|
||||||
assert arange(-1.23, 3.21, -0.0000001) == []
|
|
||||||
|
|
||||||
def test_linspace():
|
|
||||||
assert linspace(2, 9, 7) == [mpf('2.0'), mpf('3.166666666666667'),
|
|
||||||
mpf('4.3333333333333339'), mpf('5.5'), mpf('6.666666666666667'),
|
|
||||||
mpf('7.8333333333333339'), mpf('9.0')] == linspace(mpi(2, 9), 7)
|
|
||||||
assert linspace(2, 9, 7, endpoint=0) == [mpf('2.0'), mpf('3.0'), mpf('4.0'),
|
|
||||||
mpf('5.0'), mpf('6.0'), mpf('7.0'), mpf('8.0')]
|
|
||||||
assert linspace(2, 7, 1) == [mpf(2)]
|
|
||||||
|
|
||||||
def test_float_cbrt():
|
|
||||||
mp.dps = 30
|
|
||||||
for a in arange(0,10,0.1):
|
|
||||||
assert cbrt(a*a*a).ae(a, eps)
|
|
||||||
assert cbrt(-1).ae(0.5 + j*sqrt(3)/2)
|
|
||||||
one_third = mpf(1)/3
|
|
||||||
for a in arange(0,10,2.7) + [0.1 + 10**5]:
|
|
||||||
a = mpc(a + 1.1j)
|
|
||||||
r1 = cbrt(a)
|
|
||||||
mp.dps += 10
|
|
||||||
r2 = pow(a, one_third)
|
|
||||||
mp.dps -= 10
|
|
||||||
assert r1.ae(r2, eps)
|
|
||||||
mp.dps = 100
|
|
||||||
for n in range(100, 301, 100):
|
|
||||||
w = 10**n + j*10**-3
|
|
||||||
z = w*w*w
|
|
||||||
r = cbrt(z)
|
|
||||||
assert mpc_ae(r, w, eps)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_root():
|
|
||||||
mp.dps = 30
|
|
||||||
random.seed(1)
|
|
||||||
a = random.randint(0, 10000)
|
|
||||||
p = a*a*a
|
|
||||||
r = nthroot(mpf(p), 3)
|
|
||||||
assert r == a
|
|
||||||
for n in range(4, 10):
|
|
||||||
p = p*a
|
|
||||||
assert nthroot(mpf(p), n) == a
|
|
||||||
mp.dps = 40
|
|
||||||
for n in range(10, 5000, 100):
|
|
||||||
for a in [random.random()*10000, random.random()*10**100]:
|
|
||||||
r = nthroot(a, n)
|
|
||||||
r1 = pow(a, mpf(1)/n)
|
|
||||||
assert r.ae(r1)
|
|
||||||
r = nthroot(a, -n)
|
|
||||||
r1 = pow(a, -mpf(1)/n)
|
|
||||||
assert r.ae(r1)
|
|
||||||
# XXX: this is broken right now
|
|
||||||
# tests for nthroot rounding
|
|
||||||
for rnd in ['nearest', 'up', 'down']:
|
|
||||||
mp.rounding = rnd
|
|
||||||
for n in [-5, -3, 3, 5]:
|
|
||||||
prec = 50
|
|
||||||
for i in xrange(10):
|
|
||||||
mp.prec = prec
|
|
||||||
a = rand()
|
|
||||||
mp.prec = 2*prec
|
|
||||||
b = a**n
|
|
||||||
mp.prec = prec
|
|
||||||
r = nthroot(b, n)
|
|
||||||
assert r == a
|
|
||||||
mp.dps = 30
|
|
||||||
for n in range(3, 21):
|
|
||||||
a = (random.random() + j*random.random())
|
|
||||||
assert nthroot(a, n).ae(pow(a, mpf(1)/n))
|
|
||||||
assert mpc_ae(nthroot(a, n), pow(a, mpf(1)/n))
|
|
||||||
a = (random.random()*10**100 + j*random.random())
|
|
||||||
r = nthroot(a, n)
|
|
||||||
mp.dps += 4
|
|
||||||
r1 = pow(a, mpf(1)/n)
|
|
||||||
mp.dps -= 4
|
|
||||||
assert r.ae(r1)
|
|
||||||
assert mpc_ae(r, r1, eps)
|
|
||||||
r = nthroot(a, -n)
|
|
||||||
mp.dps += 4
|
|
||||||
r1 = pow(a, -mpf(1)/n)
|
|
||||||
mp.dps -= 4
|
|
||||||
assert r.ae(r1)
|
|
||||||
assert mpc_ae(r, r1, eps)
|
|
||||||
mp.dps = 15
|
|
||||||
assert nthroot(4, 1) == 4
|
|
||||||
assert nthroot(4, 0) == 1
|
|
||||||
assert nthroot(4, -1) == 0.25
|
|
||||||
assert nthroot(inf, 1) == inf
|
|
||||||
assert nthroot(inf, 2) == inf
|
|
||||||
assert nthroot(inf, 3) == inf
|
|
||||||
assert nthroot(inf, -1) == 0
|
|
||||||
assert nthroot(inf, -2) == 0
|
|
||||||
assert nthroot(inf, -3) == 0
|
|
||||||
assert nthroot(j, 1) == j
|
|
||||||
assert nthroot(j, 0) == 1
|
|
||||||
assert nthroot(j, -1) == -j
|
|
||||||
assert isnan(nthroot(nan, 1))
|
|
||||||
assert isnan(nthroot(nan, 0))
|
|
||||||
assert isnan(nthroot(nan, -1))
|
|
||||||
assert isnan(nthroot(inf, 0))
|
|
||||||
assert root(2,3) == nthroot(2,3)
|
|
||||||
assert root(16,4,0) == 2
|
|
||||||
assert root(16,4,1) == 2j
|
|
||||||
assert root(16,4,2) == -2
|
|
||||||
assert root(16,4,3) == -2j
|
|
||||||
assert root(16,4,4) == 2
|
|
||||||
assert root(-125,3,1) == -5
|
|
||||||
|
|
||||||
def test_issue_96():
|
|
||||||
for dps in [20, 80]:
|
|
||||||
mp.dps = dps
|
|
||||||
r = nthroot(mpf('-1e-20'), 4)
|
|
||||||
assert r.ae(mpf(10)**(-5) * (1 + j) * mpf(2)**(-0.5))
|
|
||||||
mp.dps = 80
|
|
||||||
assert nthroot('-1e-3', 4).ae(mpf(10)**(-3./4) * (1 + j)/sqrt(2))
|
|
||||||
assert nthroot('-1e-6', 4).ae((1 + j)/(10 * sqrt(20)))
|
|
||||||
# Check that this doesn't take eternity to compute
|
|
||||||
mp.dps = 20
|
|
||||||
assert nthroot('-1e100000000', 4).ae((1+j)*mpf('1e25000000')/sqrt(2))
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_perturbation_rounding():
|
|
||||||
mp.dps = 100
|
|
||||||
a = pi/10**50
|
|
||||||
b = -pi/10**50
|
|
||||||
c = 1 + a
|
|
||||||
d = 1 + b
|
|
||||||
mp.dps = 15
|
|
||||||
assert exp(a) == 1
|
|
||||||
assert exp(a, rounding='c') > 1
|
|
||||||
assert exp(b, rounding='c') == 1
|
|
||||||
assert exp(a, rounding='f') == 1
|
|
||||||
assert exp(b, rounding='f') < 1
|
|
||||||
assert cos(a) == 1
|
|
||||||
assert cos(a, rounding='c') == 1
|
|
||||||
assert cos(b, rounding='c') == 1
|
|
||||||
assert cos(a, rounding='f') < 1
|
|
||||||
assert cos(b, rounding='f') < 1
|
|
||||||
for f in [sin, atan, asinh, tanh]:
|
|
||||||
assert f(a) == +a
|
|
||||||
assert f(a, rounding='c') > a
|
|
||||||
assert f(a, rounding='f') < a
|
|
||||||
assert f(b) == +b
|
|
||||||
assert f(b, rounding='c') > b
|
|
||||||
assert f(b, rounding='f') < b
|
|
||||||
for f in [asin, tan, sinh, atanh]:
|
|
||||||
assert f(a) == +a
|
|
||||||
assert f(b) == +b
|
|
||||||
assert f(a, rounding='c') > a
|
|
||||||
assert f(b, rounding='c') > b
|
|
||||||
assert f(a, rounding='f') < a
|
|
||||||
assert f(b, rounding='f') < b
|
|
||||||
assert ln(c) == +a
|
|
||||||
assert ln(d) == +b
|
|
||||||
assert ln(c, rounding='c') > a
|
|
||||||
assert ln(c, rounding='f') < a
|
|
||||||
assert ln(d, rounding='c') > b
|
|
||||||
assert ln(d, rounding='f') < b
|
|
||||||
assert cosh(a) == 1
|
|
||||||
assert cosh(b) == 1
|
|
||||||
assert cosh(a, rounding='c') > 1
|
|
||||||
assert cosh(b, rounding='c') > 1
|
|
||||||
assert cosh(a, rounding='f') == 1
|
|
||||||
assert cosh(b, rounding='f') == 1
|
|
||||||
|
|
||||||
def test_integer_parts():
|
|
||||||
assert floor(3.2) == 3
|
|
||||||
assert ceil(3.2) == 4
|
|
||||||
assert floor(3.2+5j) == 3+5j
|
|
||||||
assert ceil(3.2+5j) == 4+5j
|
|
||||||
|
|
||||||
def test_complex_parts():
|
|
||||||
assert fabs('3') == 3
|
|
||||||
assert fabs(3+4j) == 5
|
|
||||||
assert re(3) == 3
|
|
||||||
assert re(1+4j) == 1
|
|
||||||
assert im(3) == 0
|
|
||||||
assert im(1+4j) == 4
|
|
||||||
assert conj(3) == 3
|
|
||||||
assert conj(3+4j) == 3-4j
|
|
||||||
assert mpf(3).conjugate() == 3
|
|
||||||
|
|
||||||
def test_cospi_sinpi():
|
|
||||||
assert sinpi(0) == 0
|
|
||||||
assert sinpi(0.5) == 1
|
|
||||||
assert sinpi(1) == 0
|
|
||||||
assert sinpi(1.5) == -1
|
|
||||||
assert sinpi(2) == 0
|
|
||||||
assert sinpi(2.5) == 1
|
|
||||||
assert sinpi(-0.5) == -1
|
|
||||||
assert cospi(0) == 1
|
|
||||||
assert cospi(0.5) == 0
|
|
||||||
assert cospi(1) == -1
|
|
||||||
assert cospi(1.5) == 0
|
|
||||||
assert cospi(2) == 1
|
|
||||||
assert cospi(2.5) == 0
|
|
||||||
assert cospi(-0.5) == 0
|
|
||||||
assert cospi(100000000000.25).ae(sqrt(2)/2)
|
|
||||||
a = cospi(2+3j)
|
|
||||||
assert a.real.ae(cos((2+3j)*pi).real)
|
|
||||||
assert a.imag == 0
|
|
||||||
b = sinpi(2+3j)
|
|
||||||
assert b.imag.ae(sin((2+3j)*pi).imag)
|
|
||||||
assert b.real == 0
|
|
||||||
mp.dps = 35
|
|
||||||
x1 = mpf(10000) - mpf('1e-15')
|
|
||||||
x2 = mpf(10000) + mpf('1e-15')
|
|
||||||
x3 = mpf(10000.5) - mpf('1e-15')
|
|
||||||
x4 = mpf(10000.5) + mpf('1e-15')
|
|
||||||
x5 = mpf(10001) - mpf('1e-15')
|
|
||||||
x6 = mpf(10001) + mpf('1e-15')
|
|
||||||
x7 = mpf(10001.5) - mpf('1e-15')
|
|
||||||
x8 = mpf(10001.5) + mpf('1e-15')
|
|
||||||
mp.dps = 15
|
|
||||||
M = 10**15
|
|
||||||
assert (sinpi(x1)*M).ae(-pi)
|
|
||||||
assert (sinpi(x2)*M).ae(pi)
|
|
||||||
assert (cospi(x3)*M).ae(pi)
|
|
||||||
assert (cospi(x4)*M).ae(-pi)
|
|
||||||
assert (sinpi(x5)*M).ae(pi)
|
|
||||||
assert (sinpi(x6)*M).ae(-pi)
|
|
||||||
assert (cospi(x7)*M).ae(-pi)
|
|
||||||
assert (cospi(x8)*M).ae(pi)
|
|
||||||
assert 0.999 < cospi(x1, rounding='d') < 1
|
|
||||||
assert 0.999 < cospi(x2, rounding='d') < 1
|
|
||||||
assert 0.999 < sinpi(x3, rounding='d') < 1
|
|
||||||
assert 0.999 < sinpi(x4, rounding='d') < 1
|
|
||||||
assert -1 < cospi(x5, rounding='d') < -0.999
|
|
||||||
assert -1 < cospi(x6, rounding='d') < -0.999
|
|
||||||
assert -1 < sinpi(x7, rounding='d') < -0.999
|
|
||||||
assert -1 < sinpi(x8, rounding='d') < -0.999
|
|
||||||
assert (sinpi(1e-15)*M).ae(pi)
|
|
||||||
assert (sinpi(-1e-15)*M).ae(-pi)
|
|
||||||
assert cospi(1e-15) == 1
|
|
||||||
assert cospi(1e-15, rounding='d') < 1
|
|
||||||
|
|
||||||
def test_expj():
|
|
||||||
assert expj(0) == 1
|
|
||||||
assert expj(1).ae(exp(j))
|
|
||||||
assert expj(j).ae(exp(-1))
|
|
||||||
assert expj(1+j).ae(exp(j*(1+j)))
|
|
||||||
assert expjpi(0) == 1
|
|
||||||
assert expjpi(1).ae(exp(j*pi))
|
|
||||||
assert expjpi(j).ae(exp(-pi))
|
|
||||||
assert expjpi(1+j).ae(exp(j*pi*(1+j)))
|
|
||||||
assert expjpi(-10**15 * j).ae('2.22579818340535731e+1364376353841841')
|
|
||||||
|
|
||||||
def test_sinc():
|
|
||||||
assert sinc(0) == sincpi(0) == 1
|
|
||||||
assert sinc(inf) == sincpi(inf) == 0
|
|
||||||
assert sinc(-inf) == sincpi(-inf) == 0
|
|
||||||
assert sinc(2).ae(0.45464871341284084770)
|
|
||||||
assert sinc(2+3j).ae(0.4463290318402435457-2.7539470277436474940j)
|
|
||||||
assert sincpi(2) == 0
|
|
||||||
assert sincpi(1.5).ae(-0.212206590789193781)
|
|
||||||
|
|
||||||
def test_fibonacci():
|
|
||||||
mp.dps = 15
|
|
||||||
assert [fibonacci(n) for n in range(-5, 10)] == \
|
|
||||||
[5, -3, 2, -1, 1, 0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
|
|
||||||
assert fib(2.5).ae(1.4893065462657091)
|
|
||||||
assert fib(3+4j).ae(-5248.51130728372 - 14195.962288353j)
|
|
||||||
assert fib(1000).ae(4.3466557686937455e+208)
|
|
||||||
assert str(fib(10**100)) == '6.24499112864607e+2089876402499787337692720892375554168224592399182109535392875613974104853496745963277658556235103534'
|
|
||||||
mp.dps = 2100
|
|
||||||
a = fib(10000)
|
|
||||||
assert a % 10**10 == 9947366875
|
|
||||||
mp.dps = 15
|
|
||||||
assert fibonacci(inf) == inf
|
|
||||||
assert fib(3+0j) == 2
|
|
||||||
|
|
||||||
def test_call_with_dps():
|
|
||||||
mp.dps = 15
|
|
||||||
assert abs(exp(1, dps=30)-e(dps=35)) < 1e-29
|
|
||||||
|
|
||||||
def test_tanh():
|
|
||||||
mp.dps = 15
|
|
||||||
assert tanh(0) == 0
|
|
||||||
assert tanh(inf) == 1
|
|
||||||
assert tanh(-inf) == -1
|
|
||||||
assert isnan(tanh(nan))
|
|
||||||
assert tanh(mpc('inf', '0')) == 1
|
|
||||||
|
|
||||||
def test_atanh():
|
|
||||||
mp.dps = 15
|
|
||||||
assert atanh(0) == 0
|
|
||||||
assert atanh(0.5).ae(0.54930614433405484570)
|
|
||||||
assert atanh(-0.5).ae(-0.54930614433405484570)
|
|
||||||
assert atanh(1) == inf
|
|
||||||
assert atanh(-1) == -inf
|
|
||||||
assert isnan(atanh(nan))
|
|
||||||
assert isinstance(atanh(1), mpf)
|
|
||||||
assert isinstance(atanh(-1), mpf)
|
|
||||||
# Limits at infinity
|
|
||||||
jpi2 = j*pi/2
|
|
||||||
assert atanh(inf).ae(-jpi2)
|
|
||||||
assert atanh(-inf).ae(jpi2)
|
|
||||||
assert atanh(mpc(inf,-1)).ae(-jpi2)
|
|
||||||
assert atanh(mpc(inf,0)).ae(-jpi2)
|
|
||||||
assert atanh(mpc(inf,1)).ae(jpi2)
|
|
||||||
assert atanh(mpc(1,inf)).ae(jpi2)
|
|
||||||
assert atanh(mpc(0,inf)).ae(jpi2)
|
|
||||||
assert atanh(mpc(-1,inf)).ae(jpi2)
|
|
||||||
assert atanh(mpc(-inf,1)).ae(jpi2)
|
|
||||||
assert atanh(mpc(-inf,0)).ae(jpi2)
|
|
||||||
assert atanh(mpc(-inf,-1)).ae(-jpi2)
|
|
||||||
assert atanh(mpc(-1,-inf)).ae(-jpi2)
|
|
||||||
assert atanh(mpc(0,-inf)).ae(-jpi2)
|
|
||||||
assert atanh(mpc(1,-inf)).ae(-jpi2)
|
|
||||||
|
|
||||||
def test_expm1():
|
|
||||||
mp.dps = 15
|
|
||||||
assert expm1(0) == 0
|
|
||||||
assert expm1(3).ae(exp(3)-1)
|
|
||||||
assert expm1(inf) == inf
|
|
||||||
assert expm1(1e-10)*1e10
|
|
||||||
assert expm1(1e-50).ae(1e-50)
|
|
||||||
assert (expm1(1e-10)*1e10).ae(1.00000000005)
|
|
||||||
|
|
||||||
def test_powm1():
|
|
||||||
mp.dps = 15
|
|
||||||
assert powm1(2,3) == 7
|
|
||||||
assert powm1(-1,2) == 0
|
|
||||||
assert powm1(-1,0) == 0
|
|
||||||
assert powm1(-2,0) == 0
|
|
||||||
assert powm1(3+4j,0) == 0
|
|
||||||
assert powm1(0,1) == -1
|
|
||||||
assert powm1(0,0) == 0
|
|
||||||
assert powm1(1,0) == 0
|
|
||||||
assert powm1(1,2) == 0
|
|
||||||
assert powm1(1,3+4j) == 0
|
|
||||||
assert powm1(1,5) == 0
|
|
||||||
assert powm1(j,4) == 0
|
|
||||||
assert powm1(-j,4) == 0
|
|
||||||
assert (powm1(2,1e-100)*1e100).ae(ln2)
|
|
||||||
assert powm1(2,'1e-100000000000') != 0
|
|
||||||
assert (powm1(fadd(1,1e-100,exact=True), 5)*1e100).ae(5)
|
|
||||||
|
|
||||||
def test_unitroots():
|
|
||||||
assert unitroots(1) == [1]
|
|
||||||
assert unitroots(2) == [1, -1]
|
|
||||||
a, b, c = unitroots(3)
|
|
||||||
assert a == 1
|
|
||||||
assert b.ae(-0.5 + 0.86602540378443864676j)
|
|
||||||
assert c.ae(-0.5 - 0.86602540378443864676j)
|
|
||||||
assert unitroots(1, primitive=True) == [1]
|
|
||||||
assert unitroots(2, primitive=True) == [-1]
|
|
||||||
assert unitroots(3, primitive=True) == unitroots(3)[1:]
|
|
||||||
assert unitroots(4, primitive=True) == [j, -j]
|
|
||||||
assert len(unitroots(17, primitive=True)) == 16
|
|
||||||
assert len(unitroots(16, primitive=True)) == 8
|
|
||||||
|
|
||||||
def test_cyclotomic():
|
|
||||||
mp.dps = 15
|
|
||||||
assert [cyclotomic(n,1) for n in range(31)] == [1,0,2,3,2,5,1,7,2,3,1,11,1,13,1,1,2,17,1,19,1,1,1,23,1,5,1,3,1,29,1]
|
|
||||||
assert [cyclotomic(n,-1) for n in range(31)] == [1,-2,0,1,2,1,3,1,2,1,5,1,1,1,7,1,2,1,3,1,1,1,11,1,1,1,13,1,1,1,1]
|
|
||||||
assert [cyclotomic(n,j) for n in range(21)] == [1,-1+j,1+j,j,0,1,-j,j,2,-j,1,j,3,1,-j,1,2,1,j,j,5]
|
|
||||||
assert [cyclotomic(n,-j) for n in range(21)] == [1,-1-j,1-j,-j,0,1,j,-j,2,j,1,-j,3,1,j,1,2,1,-j,-j,5]
|
|
||||||
assert cyclotomic(1624,j) == 1
|
|
||||||
assert cyclotomic(33600,j) == 1
|
|
||||||
u = sqrt(j, prec=500)
|
|
||||||
assert cyclotomic(8, u).ae(0)
|
|
||||||
assert cyclotomic(30, u).ae(5.8284271247461900976)
|
|
||||||
assert cyclotomic(2040, u).ae(1)
|
|
||||||
assert cyclotomic(0,2.5) == 1
|
|
||||||
assert cyclotomic(1,2.5) == 2.5-1
|
|
||||||
assert cyclotomic(2,2.5) == 2.5+1
|
|
||||||
assert cyclotomic(3,2.5) == 2.5**2 + 2.5 + 1
|
|
||||||
assert cyclotomic(7,2.5) == 406.234375
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,658 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
from mpmath.libmp import round_up, from_float, mpf_zeta_int
|
|
||||||
|
|
||||||
def test_zeta_int_bug():
|
|
||||||
assert mpf_zeta_int(0, 10) == from_float(-0.5)
|
|
||||||
|
|
||||||
def test_bernoulli():
|
|
||||||
assert bernfrac(0) == (1,1)
|
|
||||||
assert bernfrac(1) == (-1,2)
|
|
||||||
assert bernfrac(2) == (1,6)
|
|
||||||
assert bernfrac(3) == (0,1)
|
|
||||||
assert bernfrac(4) == (-1,30)
|
|
||||||
assert bernfrac(5) == (0,1)
|
|
||||||
assert bernfrac(6) == (1,42)
|
|
||||||
assert bernfrac(8) == (-1,30)
|
|
||||||
assert bernfrac(10) == (5,66)
|
|
||||||
assert bernfrac(12) == (-691,2730)
|
|
||||||
assert bernfrac(18) == (43867,798)
|
|
||||||
p, q = bernfrac(228)
|
|
||||||
assert p % 10**10 == 164918161
|
|
||||||
assert q == 625170
|
|
||||||
p, q = bernfrac(1000)
|
|
||||||
assert p % 10**10 == 7950421099
|
|
||||||
assert q == 342999030
|
|
||||||
mp.dps = 15
|
|
||||||
assert bernoulli(0) == 1
|
|
||||||
assert bernoulli(1) == -0.5
|
|
||||||
assert bernoulli(2).ae(1./6)
|
|
||||||
assert bernoulli(3) == 0
|
|
||||||
assert bernoulli(4).ae(-1./30)
|
|
||||||
assert bernoulli(5) == 0
|
|
||||||
assert bernoulli(6).ae(1./42)
|
|
||||||
assert str(bernoulli(10)) == '0.0757575757575758'
|
|
||||||
assert str(bernoulli(234)) == '7.62772793964344e+267'
|
|
||||||
assert str(bernoulli(10**5)) == '-5.82229431461335e+376755'
|
|
||||||
assert str(bernoulli(10**8+2)) == '1.19570355039953e+676752584'
|
|
||||||
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(bernoulli(10)) == '0.075757575757575757575757575757575757575757575757576'
|
|
||||||
assert str(bernoulli(234)) == '7.6277279396434392486994969020496121553385863373331e+267'
|
|
||||||
assert str(bernoulli(10**5)) == '-5.8222943146133508236497045360612887555320691004308e+376755'
|
|
||||||
assert str(bernoulli(10**8+2)) == '1.1957035503995297272263047884604346914602088317782e+676752584'
|
|
||||||
|
|
||||||
mp.dps = 1000
|
|
||||||
assert bernoulli(10).ae(mpf(5)/66)
|
|
||||||
|
|
||||||
mp.dps = 50000
|
|
||||||
assert bernoulli(10).ae(mpf(5)/66)
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_bernpoly_eulerpoly():
|
|
||||||
mp.dps = 15
|
|
||||||
assert bernpoly(0,-1).ae(1)
|
|
||||||
assert bernpoly(0,0).ae(1)
|
|
||||||
assert bernpoly(0,'1/2').ae(1)
|
|
||||||
assert bernpoly(0,'3/4').ae(1)
|
|
||||||
assert bernpoly(0,1).ae(1)
|
|
||||||
assert bernpoly(0,2).ae(1)
|
|
||||||
assert bernpoly(1,-1).ae('-3/2')
|
|
||||||
assert bernpoly(1,0).ae('-1/2')
|
|
||||||
assert bernpoly(1,'1/2').ae(0)
|
|
||||||
assert bernpoly(1,'3/4').ae('1/4')
|
|
||||||
assert bernpoly(1,1).ae('1/2')
|
|
||||||
assert bernpoly(1,2).ae('3/2')
|
|
||||||
assert bernpoly(2,-1).ae('13/6')
|
|
||||||
assert bernpoly(2,0).ae('1/6')
|
|
||||||
assert bernpoly(2,'1/2').ae('-1/12')
|
|
||||||
assert bernpoly(2,'3/4').ae('-1/48')
|
|
||||||
assert bernpoly(2,1).ae('1/6')
|
|
||||||
assert bernpoly(2,2).ae('13/6')
|
|
||||||
assert bernpoly(3,-1).ae(-3)
|
|
||||||
assert bernpoly(3,0).ae(0)
|
|
||||||
assert bernpoly(3,'1/2').ae(0)
|
|
||||||
assert bernpoly(3,'3/4').ae('-3/64')
|
|
||||||
assert bernpoly(3,1).ae(0)
|
|
||||||
assert bernpoly(3,2).ae(3)
|
|
||||||
assert bernpoly(4,-1).ae('119/30')
|
|
||||||
assert bernpoly(4,0).ae('-1/30')
|
|
||||||
assert bernpoly(4,'1/2').ae('7/240')
|
|
||||||
assert bernpoly(4,'3/4').ae('7/3840')
|
|
||||||
assert bernpoly(4,1).ae('-1/30')
|
|
||||||
assert bernpoly(4,2).ae('119/30')
|
|
||||||
assert bernpoly(5,-1).ae(-5)
|
|
||||||
assert bernpoly(5,0).ae(0)
|
|
||||||
assert bernpoly(5,'1/2').ae(0)
|
|
||||||
assert bernpoly(5,'3/4').ae('25/1024')
|
|
||||||
assert bernpoly(5,1).ae(0)
|
|
||||||
assert bernpoly(5,2).ae(5)
|
|
||||||
assert bernpoly(10,-1).ae('665/66')
|
|
||||||
assert bernpoly(10,0).ae('5/66')
|
|
||||||
assert bernpoly(10,'1/2').ae('-2555/33792')
|
|
||||||
assert bernpoly(10,'3/4').ae('-2555/34603008')
|
|
||||||
assert bernpoly(10,1).ae('5/66')
|
|
||||||
assert bernpoly(10,2).ae('665/66')
|
|
||||||
assert bernpoly(11,-1).ae(-11)
|
|
||||||
assert bernpoly(11,0).ae(0)
|
|
||||||
assert bernpoly(11,'1/2').ae(0)
|
|
||||||
assert bernpoly(11,'3/4').ae('-555731/4194304')
|
|
||||||
assert bernpoly(11,1).ae(0)
|
|
||||||
assert bernpoly(11,2).ae(11)
|
|
||||||
assert eulerpoly(0,-1).ae(1)
|
|
||||||
assert eulerpoly(0,0).ae(1)
|
|
||||||
assert eulerpoly(0,'1/2').ae(1)
|
|
||||||
assert eulerpoly(0,'3/4').ae(1)
|
|
||||||
assert eulerpoly(0,1).ae(1)
|
|
||||||
assert eulerpoly(0,2).ae(1)
|
|
||||||
assert eulerpoly(1,-1).ae('-3/2')
|
|
||||||
assert eulerpoly(1,0).ae('-1/2')
|
|
||||||
assert eulerpoly(1,'1/2').ae(0)
|
|
||||||
assert eulerpoly(1,'3/4').ae('1/4')
|
|
||||||
assert eulerpoly(1,1).ae('1/2')
|
|
||||||
assert eulerpoly(1,2).ae('3/2')
|
|
||||||
assert eulerpoly(2,-1).ae(2)
|
|
||||||
assert eulerpoly(2,0).ae(0)
|
|
||||||
assert eulerpoly(2,'1/2').ae('-1/4')
|
|
||||||
assert eulerpoly(2,'3/4').ae('-3/16')
|
|
||||||
assert eulerpoly(2,1).ae(0)
|
|
||||||
assert eulerpoly(2,2).ae(2)
|
|
||||||
assert eulerpoly(3,-1).ae('-9/4')
|
|
||||||
assert eulerpoly(3,0).ae('1/4')
|
|
||||||
assert eulerpoly(3,'1/2').ae(0)
|
|
||||||
assert eulerpoly(3,'3/4').ae('-11/64')
|
|
||||||
assert eulerpoly(3,1).ae('-1/4')
|
|
||||||
assert eulerpoly(3,2).ae('9/4')
|
|
||||||
assert eulerpoly(4,-1).ae(2)
|
|
||||||
assert eulerpoly(4,0).ae(0)
|
|
||||||
assert eulerpoly(4,'1/2').ae('5/16')
|
|
||||||
assert eulerpoly(4,'3/4').ae('57/256')
|
|
||||||
assert eulerpoly(4,1).ae(0)
|
|
||||||
assert eulerpoly(4,2).ae(2)
|
|
||||||
assert eulerpoly(5,-1).ae('-3/2')
|
|
||||||
assert eulerpoly(5,0).ae('-1/2')
|
|
||||||
assert eulerpoly(5,'1/2').ae(0)
|
|
||||||
assert eulerpoly(5,'3/4').ae('361/1024')
|
|
||||||
assert eulerpoly(5,1).ae('1/2')
|
|
||||||
assert eulerpoly(5,2).ae('3/2')
|
|
||||||
assert eulerpoly(10,-1).ae(2)
|
|
||||||
assert eulerpoly(10,0).ae(0)
|
|
||||||
assert eulerpoly(10,'1/2').ae('-50521/1024')
|
|
||||||
assert eulerpoly(10,'3/4').ae('-36581523/1048576')
|
|
||||||
assert eulerpoly(10,1).ae(0)
|
|
||||||
assert eulerpoly(10,2).ae(2)
|
|
||||||
assert eulerpoly(11,-1).ae('-699/4')
|
|
||||||
assert eulerpoly(11,0).ae('691/4')
|
|
||||||
assert eulerpoly(11,'1/2').ae(0)
|
|
||||||
assert eulerpoly(11,'3/4').ae('-512343611/4194304')
|
|
||||||
assert eulerpoly(11,1).ae('-691/4')
|
|
||||||
assert eulerpoly(11,2).ae('699/4')
|
|
||||||
# Potential accuracy issues
|
|
||||||
assert bernpoly(10000,10000).ae('5.8196915936323387117e+39999')
|
|
||||||
assert bernpoly(200,17.5).ae(3.8048418524583064909e244)
|
|
||||||
assert eulerpoly(200,17.5).ae(-3.7309911582655785929e275)
|
|
||||||
|
|
||||||
def test_gamma():
|
|
||||||
mp.dps = 15
|
|
||||||
assert gamma(0.25).ae(3.6256099082219083119)
|
|
||||||
assert gamma(0.0001).ae(9999.4228832316241908)
|
|
||||||
assert gamma(300).ae('1.0201917073881354535e612')
|
|
||||||
assert gamma(-0.5).ae(-3.5449077018110320546)
|
|
||||||
assert gamma(-7.43).ae(0.00026524416464197007186)
|
|
||||||
#assert gamma(Rational(1,2)) == gamma(0.5)
|
|
||||||
#assert gamma(Rational(-7,3)).ae(gamma(mpf(-7)/3))
|
|
||||||
assert gamma(1+1j).ae(0.49801566811835604271 - 0.15494982830181068512j)
|
|
||||||
assert gamma(-1+0.01j).ae(-0.422733904013474115 + 99.985883082635367436j)
|
|
||||||
assert gamma(20+30j).ae(-1453876687.5534810 + 1163777777.8031573j)
|
|
||||||
# Should always give exact factorials when they can
|
|
||||||
# be represented as mpfs under the current working precision
|
|
||||||
fact = 1
|
|
||||||
for i in range(1, 18):
|
|
||||||
assert gamma(i) == fact
|
|
||||||
fact *= i
|
|
||||||
for dps in [170, 600]:
|
|
||||||
fact = 1
|
|
||||||
mp.dps = dps
|
|
||||||
for i in range(1, 105):
|
|
||||||
assert gamma(i) == fact
|
|
||||||
fact *= i
|
|
||||||
mp.dps = 100
|
|
||||||
assert gamma(0.5).ae(sqrt(pi))
|
|
||||||
mp.dps = 15
|
|
||||||
assert factorial(0) == fac(0) == 1
|
|
||||||
assert factorial(3) == 6
|
|
||||||
assert isnan(gamma(nan))
|
|
||||||
assert gamma(1100).ae('4.8579168073569433667e2866')
|
|
||||||
|
|
||||||
def test_fac2():
|
|
||||||
mp.dps = 15
|
|
||||||
assert [fac2(n) for n in range(10)] == [1,1,2,3,8,15,48,105,384,945]
|
|
||||||
assert fac2(-5).ae(1./3)
|
|
||||||
assert fac2(-11).ae(-1./945)
|
|
||||||
assert fac2(50).ae(5.20469842636666623e32)
|
|
||||||
assert fac2(0.5+0.75j).ae(0.81546769394688069176-0.34901016085573266889j)
|
|
||||||
assert fac2(inf) == inf
|
|
||||||
assert isnan(fac2(-inf))
|
|
||||||
|
|
||||||
def test_gamma_quotients():
|
|
||||||
mp.dps = 15
|
|
||||||
h = 1e-8
|
|
||||||
ep = 1e-4
|
|
||||||
G = gamma
|
|
||||||
assert gammaprod([-1],[-3,-4]) == 0
|
|
||||||
assert gammaprod([-1,0],[-5]) == inf
|
|
||||||
assert abs(gammaprod([-1],[-2]) - G(-1+h)/G(-2+h)) < 1e-4
|
|
||||||
assert abs(gammaprod([-4,-3],[-2,0]) - G(-4+h)*G(-3+h)/G(-2+h)/G(0+h)) < 1e-4
|
|
||||||
assert rf(3,0) == 1
|
|
||||||
assert rf(2.5,1) == 2.5
|
|
||||||
assert rf(-5,2) == 20
|
|
||||||
assert rf(j,j).ae(gamma(2*j)/gamma(j))
|
|
||||||
assert ff(-2,0) == 1
|
|
||||||
assert ff(-2,1) == -2
|
|
||||||
assert ff(4,3) == 24
|
|
||||||
assert ff(3,4) == 0
|
|
||||||
assert binomial(0,0) == 1
|
|
||||||
assert binomial(1,0) == 1
|
|
||||||
assert binomial(0,-1) == 0
|
|
||||||
assert binomial(3,2) == 3
|
|
||||||
assert binomial(5,2) == 10
|
|
||||||
assert binomial(5,3) == 10
|
|
||||||
assert binomial(5,5) == 1
|
|
||||||
assert binomial(-1,0) == 1
|
|
||||||
assert binomial(-2,-4) == 3
|
|
||||||
assert binomial(4.5, 1.5) == 6.5625
|
|
||||||
assert binomial(1100,1) == 1100
|
|
||||||
assert binomial(1100,2) == 604450
|
|
||||||
assert beta(1,1) == 1
|
|
||||||
assert beta(0,0) == inf
|
|
||||||
assert beta(3,0) == inf
|
|
||||||
assert beta(-1,-1) == inf
|
|
||||||
assert beta(1.5,1).ae(2/3.)
|
|
||||||
assert beta(1.5,2.5).ae(pi/16)
|
|
||||||
assert (10**15*beta(10,100)).ae(2.3455339739604649879)
|
|
||||||
assert beta(inf,inf) == 0
|
|
||||||
assert isnan(beta(-inf,inf))
|
|
||||||
assert isnan(beta(-3,inf))
|
|
||||||
assert isnan(beta(0,inf))
|
|
||||||
assert beta(inf,0.5) == beta(0.5,inf) == 0
|
|
||||||
assert beta(inf,-1.5) == inf
|
|
||||||
assert beta(inf,-0.5) == -inf
|
|
||||||
assert beta(1+2j,-1-j/2).ae(1.16396542451069943086+0.08511695947832914640j)
|
|
||||||
assert beta(-0.5,0.5) == 0
|
|
||||||
assert beta(-3,3).ae(-1/3.)
|
|
||||||
|
|
||||||
def test_zeta():
|
|
||||||
mp.dps = 15
|
|
||||||
assert zeta(2).ae(pi**2 / 6)
|
|
||||||
assert zeta(2.0).ae(pi**2 / 6)
|
|
||||||
assert zeta(mpc(2)).ae(pi**2 / 6)
|
|
||||||
assert zeta(100).ae(1)
|
|
||||||
assert zeta(0).ae(-0.5)
|
|
||||||
assert zeta(0.5).ae(-1.46035450880958681)
|
|
||||||
assert zeta(-1).ae(-mpf(1)/12)
|
|
||||||
assert zeta(-2) == 0
|
|
||||||
assert zeta(-3).ae(mpf(1)/120)
|
|
||||||
assert zeta(-4) == 0
|
|
||||||
assert zeta(-100) == 0
|
|
||||||
assert isnan(zeta(nan))
|
|
||||||
# Zeros in the critical strip
|
|
||||||
assert zeta(mpc(0.5, 14.1347251417346937904)).ae(0)
|
|
||||||
assert zeta(mpc(0.5, 21.0220396387715549926)).ae(0)
|
|
||||||
assert zeta(mpc(0.5, 25.0108575801456887632)).ae(0)
|
|
||||||
mp.dps = 50
|
|
||||||
im = '236.5242296658162058024755079556629786895294952121891237'
|
|
||||||
assert zeta(mpc(0.5, im)).ae(0, 1e-46)
|
|
||||||
mp.dps = 15
|
|
||||||
# Complex reflection formula
|
|
||||||
assert (zeta(-60+3j) / 10**34).ae(8.6270183987866146+15.337398548226238j)
|
|
||||||
|
|
||||||
def test_altzeta():
|
|
||||||
mp.dps = 15
|
|
||||||
assert altzeta(-2) == 0
|
|
||||||
assert altzeta(-4) == 0
|
|
||||||
assert altzeta(-100) == 0
|
|
||||||
assert altzeta(0) == 0.5
|
|
||||||
assert altzeta(-1) == 0.25
|
|
||||||
assert altzeta(-3) == -0.125
|
|
||||||
assert altzeta(-5) == 0.25
|
|
||||||
assert altzeta(-21) == 1180529130.25
|
|
||||||
assert altzeta(1).ae(log(2))
|
|
||||||
assert altzeta(2).ae(pi**2/12)
|
|
||||||
assert altzeta(10).ae(73*pi**10/6842880)
|
|
||||||
assert altzeta(50) < 1
|
|
||||||
assert altzeta(60, rounding='d') < 1
|
|
||||||
assert altzeta(60, rounding='u') == 1
|
|
||||||
assert altzeta(10000, rounding='d') < 1
|
|
||||||
assert altzeta(10000, rounding='u') == 1
|
|
||||||
assert altzeta(3+0j) == altzeta(3)
|
|
||||||
s = 3+4j
|
|
||||||
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
|
|
||||||
s = -3+4j
|
|
||||||
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
|
|
||||||
assert altzeta(-100.5).ae(4.58595480083585913e+108)
|
|
||||||
assert altzeta(1.3).ae(0.73821404216623045)
|
|
||||||
|
|
||||||
def test_zeta_huge():
|
|
||||||
mp.dps = 15
|
|
||||||
assert zeta(inf) == 1
|
|
||||||
mp.dps = 50
|
|
||||||
assert zeta(100).ae('1.0000000000000000000000000000007888609052210118073522')
|
|
||||||
assert zeta(40*pi).ae('1.0000000000000000000000000000000000000148407238666182')
|
|
||||||
mp.dps = 10000
|
|
||||||
v = zeta(33000)
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(v-1) == '1.02363019598118e-9934'
|
|
||||||
assert zeta(pi*1000, rounding=round_up) > 1
|
|
||||||
assert zeta(3000, rounding=round_up) > 1
|
|
||||||
assert zeta(pi*1000) == 1
|
|
||||||
assert zeta(3000) == 1
|
|
||||||
|
|
||||||
def test_zeta_negative():
|
|
||||||
mp.dps = 150
|
|
||||||
a = -pi*10**40
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(zeta(a)) == '2.55880492708712e+1233536161668617575553892558646631323374078'
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(zeta(a)) == '2.5588049270871154960875033337384432038436330847333e+1233536161668617575553892558646631323374078'
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_polygamma():
|
|
||||||
mp.dps = 15
|
|
||||||
psi0 = lambda z: psi(0,z)
|
|
||||||
psi1 = lambda z: psi(1,z)
|
|
||||||
assert psi0(3) == psi(0,3) == digamma(3)
|
|
||||||
#assert psi2(3) == psi(2,3) == tetragamma(3)
|
|
||||||
#assert psi3(3) == psi(3,3) == pentagamma(3)
|
|
||||||
assert psi0(pi).ae(0.97721330794200673)
|
|
||||||
assert psi0(-pi).ae(7.8859523853854902)
|
|
||||||
assert psi0(-pi+1).ae(7.5676424992016996)
|
|
||||||
assert psi0(pi+j).ae(1.04224048313859376 + 0.35853686544063749j)
|
|
||||||
assert psi0(-pi-j).ae(1.3404026194821986 - 2.8824392476809402j)
|
|
||||||
assert findroot(psi0, 1).ae(1.4616321449683622)
|
|
||||||
assert psi0(inf) == inf
|
|
||||||
assert psi1(inf) == 0
|
|
||||||
assert psi(2,inf) == 0
|
|
||||||
assert psi1(pi).ae(0.37424376965420049)
|
|
||||||
assert psi1(-pi).ae(53.030438740085385)
|
|
||||||
assert psi1(pi+j).ae(0.32935710377142464 - 0.12222163911221135j)
|
|
||||||
assert psi1(-pi-j).ae(-0.30065008356019703 + 0.01149892486928227j)
|
|
||||||
assert (10**6*psi(4,1+10*pi*j)).ae(-6.1491803479004446 - 0.3921316371664063j)
|
|
||||||
assert psi0(1+10*pi*j).ae(3.4473994217222650 + 1.5548808324857071j)
|
|
||||||
assert isnan(psi0(nan))
|
|
||||||
assert isnan(psi0(-inf))
|
|
||||||
assert psi0(-100.5).ae(4.615124601338064)
|
|
||||||
assert psi0(3+0j).ae(psi0(3))
|
|
||||||
assert psi0(-100+3j).ae(4.6106071768714086321+3.1117510556817394626j)
|
|
||||||
|
|
||||||
def test_polygamma_high_prec():
|
|
||||||
mp.dps = 100
|
|
||||||
assert str(psi(0,pi)) == "0.9772133079420067332920694864061823436408346099943256380095232865318105924777141317302075654362928734"
|
|
||||||
assert str(psi(10,pi)) == "-12.98876181434889529310283769414222588307175962213707170773803550518307617769657562747174101900659238"
|
|
||||||
|
|
||||||
def test_polygamma_identities():
|
|
||||||
mp.dps = 15
|
|
||||||
psi0 = lambda z: psi(0,z)
|
|
||||||
psi1 = lambda z: psi(1,z)
|
|
||||||
psi2 = lambda z: psi(2,z)
|
|
||||||
assert psi0(0.5).ae(-euler-2*log(2))
|
|
||||||
assert psi0(1).ae(-euler)
|
|
||||||
assert psi1(0.5).ae(0.5*pi**2)
|
|
||||||
assert psi1(1).ae(pi**2/6)
|
|
||||||
assert psi1(0.25).ae(pi**2 + 8*catalan)
|
|
||||||
assert psi2(1).ae(-2*apery)
|
|
||||||
mp.dps = 20
|
|
||||||
u = -182*apery+4*sqrt(3)*pi**3
|
|
||||||
mp.dps = 15
|
|
||||||
assert psi(2,5/6.).ae(u)
|
|
||||||
assert psi(3,0.5).ae(pi**4)
|
|
||||||
|
|
||||||
def test_foxtrot_identity():
|
|
||||||
# A test of the complex digamma function.
|
|
||||||
# See http://mathworld.wolfram.com/FoxTrotSeries.html and
|
|
||||||
# http://mathworld.wolfram.com/DigammaFunction.html
|
|
||||||
psi0 = lambda z: psi(0,z)
|
|
||||||
mp.dps = 50
|
|
||||||
a = (-1)**fraction(1,3)
|
|
||||||
b = (-1)**fraction(2,3)
|
|
||||||
x = -psi0(0.5*a) - psi0(-0.5*b) + psi0(0.5*(1+a)) + psi0(0.5*(1-b))
|
|
||||||
y = 2*pi*sech(0.5*sqrt(3)*pi)
|
|
||||||
assert x.ae(y)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_polygamma_high_order():
|
|
||||||
mp.dps = 100
|
|
||||||
assert str(psi(50, pi)) == "-1344100348958402765749252447726432491812.641985273160531055707095989227897753035823152397679626136483"
|
|
||||||
assert str(psi(50, pi + 14*e)) == "-0.00000000000000000189793739550804321623512073101895801993019919886375952881053090844591920308111549337295143780341396"
|
|
||||||
assert str(psi(50, pi + 14*e*j)) == ("(-0.0000000000000000522516941152169248975225472155683565752375889510631513244785"
|
|
||||||
"9377385233700094871256507814151956624433 - 0.00000000000000001813157041407010184"
|
|
||||||
"702414110218205348527862196327980417757665282244728963891298080199341480881811613j)")
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(psi(50, pi)) == "-1.34410034895841e+39"
|
|
||||||
assert str(psi(50, pi + 14*e)) == "-1.89793739550804e-18"
|
|
||||||
assert str(psi(50, pi + 14*e*j)) == "(-5.2251694115217e-17 - 1.81315704140701e-17j)"
|
|
||||||
|
|
||||||
def test_harmonic():
|
|
||||||
mp.dps = 15
|
|
||||||
assert harmonic(0) == 0
|
|
||||||
assert harmonic(1) == 1
|
|
||||||
assert harmonic(2) == 1.5
|
|
||||||
assert harmonic(3).ae(1. + 1./2 + 1./3)
|
|
||||||
assert harmonic(10**10).ae(23.603066594891989701)
|
|
||||||
assert harmonic(10**1000).ae(2303.162308658947)
|
|
||||||
assert harmonic(0.5).ae(2-2*log(2))
|
|
||||||
assert harmonic(inf) == inf
|
|
||||||
assert harmonic(2+0j) == 1.5+0j
|
|
||||||
assert harmonic(1+2j).ae(1.4918071802755104+0.92080728264223022j)
|
|
||||||
|
|
||||||
def test_gamma_huge_1():
|
|
||||||
mp.dps = 500
|
|
||||||
x = mpf(10**10) / 7
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(gamma(x)) == "6.26075321389519e+12458010678"
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(gamma(x)) == "6.2607532138951929201303779291707455874010420783933e+12458010678"
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_gamma_huge_2():
|
|
||||||
mp.dps = 500
|
|
||||||
x = mpf(10**100) / 19
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(gamma(x)) == (\
|
|
||||||
"1.82341134776679e+5172997469323364168990133558175077136829182824042201886051511"
|
|
||||||
"9656908623426021308685461258226190190661")
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(gamma(x)) == (\
|
|
||||||
"1.82341134776678875374414910350027596939980412984e+5172997469323364168990133558"
|
|
||||||
"1750771368291828240422018860515119656908623426021308685461258226190190661")
|
|
||||||
|
|
||||||
def test_gamma_huge_3():
|
|
||||||
mp.dps = 500
|
|
||||||
x = 10**80 // 3 + 10**70*j / 7
|
|
||||||
mp.dps = 15
|
|
||||||
y = gamma(x)
|
|
||||||
assert str(y.real) == (\
|
|
||||||
"-6.82925203918106e+2636286142112569524501781477865238132302397236429627932441916"
|
|
||||||
"056964386399485392600")
|
|
||||||
assert str(y.imag) == (\
|
|
||||||
"8.54647143678418e+26362861421125695245017814778652381323023972364296279324419160"
|
|
||||||
"56964386399485392600")
|
|
||||||
mp.dps = 50
|
|
||||||
y = gamma(x)
|
|
||||||
assert str(y.real) == (\
|
|
||||||
"-6.8292520391810548460682736226799637356016538421817e+26362861421125695245017814"
|
|
||||||
"77865238132302397236429627932441916056964386399485392600")
|
|
||||||
assert str(y.imag) == (\
|
|
||||||
"8.5464714367841748507479306948130687511711420234015e+263628614211256952450178147"
|
|
||||||
"7865238132302397236429627932441916056964386399485392600")
|
|
||||||
|
|
||||||
def test_gamma_huge_4():
|
|
||||||
x = 3200+11500j
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(gamma(x)) == \
|
|
||||||
"(8.95783268539713e+5164 - 1.94678798329735e+5164j)"
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(gamma(x)) == (\
|
|
||||||
"(8.9578326853971339570292952697675570822206567327092e+5164"
|
|
||||||
" - 1.9467879832973509568895402139429643650329524144794e+51"
|
|
||||||
"64j)")
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_gamma_huge_5():
|
|
||||||
mp.dps = 500
|
|
||||||
x = 10**60 * j / 3
|
|
||||||
mp.dps = 15
|
|
||||||
y = gamma(x)
|
|
||||||
assert str(y.real) == "-3.27753899634941e-227396058973640224580963937571892628368354580620654233316839"
|
|
||||||
assert str(y.imag) == "-7.1519888950416e-227396058973640224580963937571892628368354580620654233316841"
|
|
||||||
mp.dps = 50
|
|
||||||
y = gamma(x)
|
|
||||||
assert str(y.real) == (\
|
|
||||||
"-3.2775389963494132168950056995974690946983219123935e-22739605897364022458096393"
|
|
||||||
"7571892628368354580620654233316839")
|
|
||||||
assert str(y.imag) == (\
|
|
||||||
"-7.1519888950415979749736749222530209713136588885897e-22739605897364022458096393"
|
|
||||||
"7571892628368354580620654233316841")
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
"""
|
|
||||||
XXX: fails
|
|
||||||
def test_gamma_huge_6():
|
|
||||||
return
|
|
||||||
mp.dps = 500
|
|
||||||
x = -10**10 + mpf(10)**(-175)*j
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(gamma(x)) == \
|
|
||||||
"(1.86729378905343e-95657055178 - 4.29960285282433e-95657055002j)"
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(gamma(x)) == (\
|
|
||||||
"(1.8672937890534298925763143275474177736153484820662e-9565705517"
|
|
||||||
"8 - 4.2996028528243336966001185406200082244961757496106e-9565705"
|
|
||||||
"5002j)")
|
|
||||||
mp.dps = 15
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_gamma_huge_7():
|
|
||||||
mp.dps = 100
|
|
||||||
a = 3 + j/mpf(10)**1000
|
|
||||||
mp.dps = 15
|
|
||||||
y = gamma(a)
|
|
||||||
assert str(y.real) == "2.0"
|
|
||||||
assert str(y.imag) == "2.16735365342606e-1000"
|
|
||||||
mp.dps = 50
|
|
||||||
y = gamma(a)
|
|
||||||
assert str(y.real) == "2.0"
|
|
||||||
assert str(y.imag) == "2.1673536534260596065418805612488708028522563689298e-1000"
|
|
||||||
|
|
||||||
def test_stieltjes():
|
|
||||||
mp.dps = 15
|
|
||||||
assert stieltjes(0).ae(+euler)
|
|
||||||
mp.dps = 25
|
|
||||||
assert stieltjes(1).ae('-0.07281584548367672486058637587')
|
|
||||||
assert stieltjes(2).ae('-0.009690363192872318484530386035')
|
|
||||||
assert stieltjes(3).ae('0.002053834420303345866160046543')
|
|
||||||
assert stieltjes(4).ae('0.002325370065467300057468170178')
|
|
||||||
mp.dps = 15
|
|
||||||
assert stieltjes(1).ae(-0.07281584548367672486058637587)
|
|
||||||
assert stieltjes(2).ae(-0.009690363192872318484530386035)
|
|
||||||
assert stieltjes(3).ae(0.002053834420303345866160046543)
|
|
||||||
assert stieltjes(4).ae(0.0023253700654673000574681701775)
|
|
||||||
|
|
||||||
def test_barnesg():
|
|
||||||
mp.dps = 15
|
|
||||||
assert barnesg(0) == barnesg(-1) == 0
|
|
||||||
assert [superfac(i) for i in range(8)] == [1, 1, 2, 12, 288, 34560, 24883200, 125411328000]
|
|
||||||
assert str(superfac(1000)) == '3.24570818422368e+1177245'
|
|
||||||
assert isnan(barnesg(nan))
|
|
||||||
assert isnan(superfac(nan))
|
|
||||||
assert isnan(hyperfac(nan))
|
|
||||||
assert barnesg(inf) == inf
|
|
||||||
assert superfac(inf) == inf
|
|
||||||
assert hyperfac(inf) == inf
|
|
||||||
assert isnan(superfac(-inf))
|
|
||||||
assert barnesg(0.7).ae(0.8068722730141471)
|
|
||||||
assert barnesg(2+3j).ae(-0.17810213864082169+0.04504542715447838j)
|
|
||||||
assert [hyperfac(n) for n in range(7)] == [1, 1, 4, 108, 27648, 86400000, 4031078400000]
|
|
||||||
assert [hyperfac(n) for n in range(0,-7,-1)] == [1,1,-1,-4,108,27648,-86400000]
|
|
||||||
a = barnesg(-3+0j)
|
|
||||||
assert a == 0 and isinstance(a, mpc)
|
|
||||||
a = hyperfac(-3+0j)
|
|
||||||
assert a == -4 and isinstance(a, mpc)
|
|
||||||
|
|
||||||
def test_polylog():
|
|
||||||
mp.dps = 15
|
|
||||||
zs = [mpmathify(z) for z in [0, 0.5, 0.99, 4, -0.5, -4, 1j, 3+4j]]
|
|
||||||
for z in zs: assert polylog(1, z).ae(-log(1-z))
|
|
||||||
for z in zs: assert polylog(0, z).ae(z/(1-z))
|
|
||||||
for z in zs: assert polylog(-1, z).ae(z/(1-z)**2)
|
|
||||||
for z in zs: assert polylog(-2, z).ae(z*(1+z)/(1-z)**3)
|
|
||||||
for z in zs: assert polylog(-3, z).ae(z*(1+4*z+z**2)/(1-z)**4)
|
|
||||||
assert polylog(3, 7).ae(5.3192579921456754382-5.9479244480803301023j)
|
|
||||||
assert polylog(3, -7).ae(-4.5693548977219423182)
|
|
||||||
assert polylog(2, 0.9).ae(1.2997147230049587252)
|
|
||||||
assert polylog(2, -0.9).ae(-0.75216317921726162037)
|
|
||||||
assert polylog(2, 0.9j).ae(-0.17177943786580149299+0.83598828572550503226j)
|
|
||||||
assert polylog(2, 1.1).ae(1.9619991013055685931-0.2994257606855892575j)
|
|
||||||
assert polylog(2, -1.1).ae(-0.89083809026228260587)
|
|
||||||
assert polylog(2, 1.1*sqrt(j)).ae(0.58841571107611387722+1.09962542118827026011j)
|
|
||||||
assert polylog(-2, 0.9).ae(1710)
|
|
||||||
assert polylog(-2, -0.9).ae(-90/6859.)
|
|
||||||
assert polylog(3, 0.9).ae(1.0496589501864398696)
|
|
||||||
assert polylog(-3, 0.9).ae(48690)
|
|
||||||
assert polylog(-3, -4).ae(-0.0064)
|
|
||||||
assert polylog(0.5+j/3, 0.5+j/2).ae(0.31739144796565650535 + 0.99255390416556261437j)
|
|
||||||
assert polylog(3+4j,1).ae(zeta(3+4j))
|
|
||||||
assert polylog(3+4j,-1).ae(-altzeta(3+4j))
|
|
||||||
|
|
||||||
def test_bell_polyexp():
|
|
||||||
mp.dps = 15
|
|
||||||
# TODO: more tests for polyexp
|
|
||||||
assert (polyexp(0,1e-10)*10**10).ae(1.00000000005)
|
|
||||||
assert (polyexp(1,1e-10)*10**10).ae(1.0000000001)
|
|
||||||
assert polyexp(5,3j).ae(-607.7044517476176454+519.962786482001476087j)
|
|
||||||
assert polyexp(-1,3.5).ae(12.09537536175543444)
|
|
||||||
# bell(0,x) = 1
|
|
||||||
assert bell(0,0) == 1
|
|
||||||
assert bell(0,1) == 1
|
|
||||||
assert bell(0,2) == 1
|
|
||||||
assert bell(0,inf) == 1
|
|
||||||
assert bell(0,-inf) == 1
|
|
||||||
assert isnan(bell(0,nan))
|
|
||||||
# bell(1,x) = x
|
|
||||||
assert bell(1,4) == 4
|
|
||||||
assert bell(1,0) == 0
|
|
||||||
assert bell(1,inf) == inf
|
|
||||||
assert bell(1,-inf) == -inf
|
|
||||||
assert isnan(bell(1,nan))
|
|
||||||
# bell(2,x) = x*(1+x)
|
|
||||||
assert bell(2,-1) == 0
|
|
||||||
assert bell(2,0) == 0
|
|
||||||
# large orders / arguments
|
|
||||||
assert bell(10) == 115975
|
|
||||||
assert bell(10,1) == 115975
|
|
||||||
assert bell(10, -8) == 11054008
|
|
||||||
assert bell(5,-50) == -253087550
|
|
||||||
assert bell(50,-50).ae('3.4746902914629720259e74')
|
|
||||||
mp.dps = 80
|
|
||||||
assert bell(50,-50) == 347469029146297202586097646631767227177164818163463279814268368579055777450
|
|
||||||
assert bell(40,50) == 5575520134721105844739265207408344706846955281965031698187656176321717550
|
|
||||||
assert bell(74) == 5006908024247925379707076470957722220463116781409659160159536981161298714301202
|
|
||||||
mp.dps = 15
|
|
||||||
assert bell(10,20j) == 7504528595600+15649605360020j
|
|
||||||
# continuity of the generalization
|
|
||||||
assert bell(0.5,0).ae(sinc(pi*0.5))
|
|
||||||
|
|
||||||
def test_primezeta():
|
|
||||||
mp.dps = 15
|
|
||||||
assert primezeta(0.9).ae(1.8388316154446882243 + 3.1415926535897932385j)
|
|
||||||
assert primezeta(4).ae(0.076993139764246844943)
|
|
||||||
assert primezeta(1) == inf
|
|
||||||
assert primezeta(inf) == 0
|
|
||||||
assert isnan(primezeta(nan))
|
|
||||||
|
|
||||||
def test_rs_zeta():
|
|
||||||
mp.dps = 15
|
|
||||||
assert zeta(0.5+100000j).ae(1.0730320148577531321 + 5.7808485443635039843j)
|
|
||||||
assert zeta(0.75+100000j).ae(1.837852337251873704 + 1.9988492668661145358j)
|
|
||||||
assert zeta(0.5+1000000j, derivative=3).ae(1647.7744105852674733 - 1423.1270943036622097j)
|
|
||||||
assert zeta(1+1000000j, derivative=3).ae(3.4085866124523582894 - 18.179184721525947301j)
|
|
||||||
assert zeta(1+1000000j, derivative=1).ae(-0.10423479366985452134 - 0.74728992803359056244j)
|
|
||||||
assert zeta(0.5-1000000j, derivative=1).ae(11.636804066002521459 + 17.127254072212996004j)
|
|
||||||
# Additional sanity tests using fp arithmetic.
|
|
||||||
# Some more high-precision tests are found in the docstrings
|
|
||||||
def ae(x, y, tol=1e-6):
|
|
||||||
return abs(x-y) < tol*abs(y)
|
|
||||||
assert ae(fp.zeta(0.5-100000j), 1.0730320148577531321 - 5.7808485443635039843j)
|
|
||||||
assert ae(fp.zeta(0.75-100000j), 1.837852337251873704 - 1.9988492668661145358j)
|
|
||||||
assert ae(fp.zeta(0.5+1e6j), 0.076089069738227100006 + 2.8051021010192989554j)
|
|
||||||
assert ae(fp.zeta(0.5+1e6j, derivative=1), 11.636804066002521459 - 17.127254072212996004j)
|
|
||||||
assert ae(fp.zeta(1+1e6j), 0.94738726251047891048 + 0.59421999312091832833j)
|
|
||||||
assert ae(fp.zeta(1+1e6j, derivative=1), -0.10423479366985452134 - 0.74728992803359056244j)
|
|
||||||
assert ae(fp.zeta(0.5+100000j, derivative=1), 10.766962036817482375 - 30.92705282105996714j)
|
|
||||||
assert ae(fp.zeta(0.5+100000j, derivative=2), -119.40515625740538429 + 217.14780631141830251j)
|
|
||||||
assert ae(fp.zeta(0.5+100000j, derivative=3), 1129.7550282628460881 - 1685.4736895169690346j)
|
|
||||||
assert ae(fp.zeta(0.5+100000j, derivative=4), -10407.160819314958615 + 13777.786698628045085j)
|
|
||||||
assert ae(fp.zeta(0.75+100000j, derivative=1), -0.41742276699594321475 - 6.4453816275049955949j)
|
|
||||||
assert ae(fp.zeta(0.75+100000j, derivative=2), -9.214314279161977266 + 35.07290795337967899j)
|
|
||||||
assert ae(fp.zeta(0.75+100000j, derivative=3), 110.61331857820103469 - 236.87847130518129926j)
|
|
||||||
assert ae(fp.zeta(0.75+100000j, derivative=4), -1054.334275898559401 + 1769.9177890161596383j)
|
|
||||||
|
|
||||||
def test_zeta_near_1():
|
|
||||||
# Test for a former bug in mpf_zeta and mpc_zeta
|
|
||||||
mp.dps = 15
|
|
||||||
s1 = fadd(1, '1e-10', exact=True)
|
|
||||||
s2 = fadd(1, '-1e-10', exact=True)
|
|
||||||
s3 = fadd(1, '1e-10j', exact=True)
|
|
||||||
assert zeta(s1).ae(1.000000000057721566490881444e10)
|
|
||||||
assert zeta(s2).ae(-9.99999999942278433510574872e9)
|
|
||||||
z = zeta(s3)
|
|
||||||
assert z.real.ae(0.57721566490153286060)
|
|
||||||
assert z.imag.ae(-9.9999999999999999999927184e9)
|
|
||||||
mp.dps = 30
|
|
||||||
s1 = fadd(1, '1e-50', exact=True)
|
|
||||||
s2 = fadd(1, '-1e-50', exact=True)
|
|
||||||
s3 = fadd(1, '1e-50j', exact=True)
|
|
||||||
assert zeta(s1).ae('1e50')
|
|
||||||
assert zeta(s2).ae('-1e50')
|
|
||||||
z = zeta(s3)
|
|
||||||
assert z.real.ae('0.57721566490153286060651209008240243104215933593992')
|
|
||||||
assert z.imag.ae('-1e50')
|
|
||||||
|
|
@ -1,292 +0,0 @@
|
||||||
"""
|
|
||||||
Check that the output from irrational functions is accurate for
|
|
||||||
high-precision input, from 5 to 200 digits. The reference values were
|
|
||||||
verified with Mathematica.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
precs = [5, 15, 28, 35, 57, 80, 100, 150, 200]
|
|
||||||
|
|
||||||
# sqrt(3) + pi/2
|
|
||||||
a = \
|
|
||||||
"3.302847134363773912758768033145623809041389953497933538543279275605"\
|
|
||||||
"841220051904536395163599428307109666700184672047856353516867399774243594"\
|
|
||||||
"67433521615861420725323528325327484262075464241255915238845599752675"
|
|
||||||
|
|
||||||
# e + 1/euler**2
|
|
||||||
b = \
|
|
||||||
"5.719681166601007617111261398629939965860873957353320734275716220045750"\
|
|
||||||
"31474116300529519620938123730851145473473708966080207482581266469342214"\
|
|
||||||
"824842256999042984813905047895479210702109260221361437411947323431"
|
|
||||||
|
|
||||||
# sqrt(a)
|
|
||||||
sqrt_a = \
|
|
||||||
"1.817373691447021556327498239690365674922395036495564333152483422755"\
|
|
||||||
"144321726165582817927383239308173567921345318453306994746434073691275094"\
|
|
||||||
"484777905906961689902608644112196725896908619756404253109722911487"
|
|
||||||
|
|
||||||
# sqrt(a+b*i).real
|
|
||||||
sqrt_abi_real = \
|
|
||||||
"2.225720098415113027729407777066107959851146508557282707197601407276"\
|
|
||||||
"89160998185797504198062911768240808839104987021515555650875977724230130"\
|
|
||||||
"3584116233925658621288393930286871862273400475179312570274423840384"
|
|
||||||
|
|
||||||
# sqrt(a+b*i).imag
|
|
||||||
sqrt_abi_imag = \
|
|
||||||
"1.2849057639084690902371581529110949983261182430040898147672052833653668"\
|
|
||||||
"0629534491275114877090834296831373498336559849050755848611854282001250"\
|
|
||||||
"1924311019152914021365263161630765255610885489295778894976075186"
|
|
||||||
|
|
||||||
# log(a)
|
|
||||||
log_a = \
|
|
||||||
"1.194784864491089550288313512105715261520511949410072046160598707069"\
|
|
||||||
"4336653155025770546309137440687056366757650909754708302115204338077595203"\
|
|
||||||
"83005773986664564927027147084436553262269459110211221152925732612"
|
|
||||||
|
|
||||||
# log(a+b*i).real
|
|
||||||
log_abi_real = \
|
|
||||||
"1.8877985921697018111624077550443297276844736840853590212962006811663"\
|
|
||||||
"04949387789489704203167470111267581371396245317618589339274243008242708"\
|
|
||||||
"014251531496104028712866224020066439049377679709216784954509456421"
|
|
||||||
|
|
||||||
# log(a+b*i).imag
|
|
||||||
log_abi_imag = \
|
|
||||||
"1.0471204952840802663567714297078763189256357109769672185219334169734948"\
|
|
||||||
"4265809854092437285294686651806426649541504240470168212723133326542181"\
|
|
||||||
"8300136462287639956713914482701017346851009323172531601894918640"
|
|
||||||
|
|
||||||
# exp(a)
|
|
||||||
exp_a = \
|
|
||||||
"27.18994224087168661137253262213293847994194869430518354305430976149"\
|
|
||||||
"382792035050358791398632888885200049857986258414049540376323785711941636"\
|
|
||||||
"100358982497583832083513086941635049329804685212200507288797531143"
|
|
||||||
|
|
||||||
# exp(a+b*i).real
|
|
||||||
exp_abi_real = \
|
|
||||||
"22.98606617170543596386921087657586890620262522816912505151109385026"\
|
|
||||||
"40160179326569526152851983847133513990281518417211964710397233157168852"\
|
|
||||||
"4963130831190142571659948419307628119985383887599493378056639916701"
|
|
||||||
|
|
||||||
# exp(a+b*i).imag
|
|
||||||
exp_abi_imag = \
|
|
||||||
"-14.523557450291489727214750571590272774669907424478129280902375851196283"\
|
|
||||||
"3377162379031724734050088565710975758824441845278120105728824497308303"\
|
|
||||||
"6065619788140201636218705414429933685889542661364184694108251449"
|
|
||||||
|
|
||||||
# a**b
|
|
||||||
pow_a_b = \
|
|
||||||
"928.7025342285568142947391505837660251004990092821305668257284426997"\
|
|
||||||
"361966028275685583421197860603126498884545336686124793155581311527995550"\
|
|
||||||
"580229264427202446131740932666832138634013168125809402143796691154"
|
|
||||||
|
|
||||||
# (a**(a+b*i)).real
|
|
||||||
pow_a_abi_real = \
|
|
||||||
"44.09156071394489511956058111704382592976814280267142206420038656267"\
|
|
||||||
"67707916510652790502399193109819563864568986234654864462095231138500505"\
|
|
||||||
"8197456514795059492120303477512711977915544927440682508821426093455"
|
|
||||||
|
|
||||||
# (a**(a+b*i)).imag
|
|
||||||
pow_a_abi_imag = \
|
|
||||||
"27.069371511573224750478105146737852141664955461266218367212527612279886"\
|
|
||||||
"9322304536553254659049205414427707675802193810711302947536332040474573"\
|
|
||||||
"8166261217563960235014674118610092944307893857862518964990092301"
|
|
||||||
|
|
||||||
# ((a+b*i)**(a+b*i)).real
|
|
||||||
pow_abi_abi_real = \
|
|
||||||
"-0.15171310677859590091001057734676423076527145052787388589334350524"\
|
|
||||||
"8084195882019497779202452975350579073716811284169068082670778986235179"\
|
|
||||||
"0813026562962084477640470612184016755250592698408112493759742219150452"\
|
|
||||||
|
|
||||||
# ((a+b*i)**(a+b*i)).imag
|
|
||||||
pow_abi_abi_imag = \
|
|
||||||
"1.2697592504953448936553147870155987153192995316950583150964099070426"\
|
|
||||||
"4736837932577176947632535475040521749162383347758827307504526525647759"\
|
|
||||||
"97547638617201824468382194146854367480471892602963428122896045019902"
|
|
||||||
|
|
||||||
# sin(a)
|
|
||||||
sin_a = \
|
|
||||||
"-0.16055653857469062740274792907968048154164433772938156243509084009"\
|
|
||||||
"38437090841460493108570147191289893388608611542655654723437248152535114"\
|
|
||||||
"528368009465836614227575701220612124204622383149391870684288862269631"
|
|
||||||
|
|
||||||
# sin(1000*a)
|
|
||||||
sin_1000a = \
|
|
||||||
"-0.85897040577443833776358106803777589664322997794126153477060795801"\
|
|
||||||
"09151695416961724733492511852267067419573754315098042850381158563024337"\
|
|
||||||
"216458577140500488715469780315833217177634490142748614625281171216863"
|
|
||||||
|
|
||||||
# sin(a+b*i)
|
|
||||||
sin_abi_real = \
|
|
||||||
"-24.4696999681556977743346798696005278716053366404081910969773939630"\
|
|
||||||
"7149215135459794473448465734589287491880563183624997435193637389884206"\
|
|
||||||
"02151395451271809790360963144464736839412254746645151672423256977064"
|
|
||||||
|
|
||||||
sin_abi_imag = \
|
|
||||||
"-150.42505378241784671801405965872972765595073690984080160750785565810981"\
|
|
||||||
"8314482499135443827055399655645954830931316357243750839088113122816583"\
|
|
||||||
"7169201254329464271121058839499197583056427233866320456505060735"
|
|
||||||
|
|
||||||
# cos
|
|
||||||
cos_a = \
|
|
||||||
"-0.98702664499035378399332439243967038895709261414476495730788864004"\
|
|
||||||
"05406821549361039745258003422386169330787395654908532996287293003581554"\
|
|
||||||
"257037193284199198069707141161341820684198547572456183525659969145501"
|
|
||||||
|
|
||||||
cos_1000a = \
|
|
||||||
"-0.51202523570982001856195696460663971099692261342827540426136215533"\
|
|
||||||
"52686662667660613179619804463250686852463876088694806607652218586060613"\
|
|
||||||
"951310588158830695735537073667299449753951774916401887657320950496820"
|
|
||||||
|
|
||||||
# tan
|
|
||||||
tan_a = \
|
|
||||||
"0.162666873675188117341401059858835168007137819495998960250142156848"\
|
|
||||||
"639654718809412181543343168174807985559916643549174530459883826451064966"\
|
|
||||||
"7996119428949951351938178809444268785629011625179962457123195557310"
|
|
||||||
|
|
||||||
tan_abi_real = \
|
|
||||||
"6.822696615947538488826586186310162599974827139564433912601918442911"\
|
|
||||||
"1026830824380070400102213741875804368044342309515353631134074491271890"\
|
|
||||||
"467615882710035471686578162073677173148647065131872116479947620E-6"
|
|
||||||
|
|
||||||
tan_abi_imag = \
|
|
||||||
"0.9999795833048243692245661011298447587046967777739649018690797625964167"\
|
|
||||||
"1446419978852235960862841608081413169601038230073129482874832053357571"\
|
|
||||||
"62702259309150715669026865777947502665936317953101462202542168429"
|
|
||||||
|
|
||||||
|
|
||||||
def test_hp():
|
|
||||||
for dps in precs:
|
|
||||||
mp.dps = dps + 8
|
|
||||||
aa = mpf(a)
|
|
||||||
bb = mpf(b)
|
|
||||||
a1000 = 1000*mpf(a)
|
|
||||||
abi = mpc(aa, bb)
|
|
||||||
mp.dps = dps
|
|
||||||
assert (sqrt(3) + pi/2).ae(aa)
|
|
||||||
assert (e + 1/euler**2).ae(bb)
|
|
||||||
|
|
||||||
assert sqrt(aa).ae(mpf(sqrt_a))
|
|
||||||
assert sqrt(abi).ae(mpc(sqrt_abi_real, sqrt_abi_imag))
|
|
||||||
|
|
||||||
assert log(aa).ae(mpf(log_a))
|
|
||||||
assert log(abi).ae(mpc(log_abi_real, log_abi_imag))
|
|
||||||
|
|
||||||
assert exp(aa).ae(mpf(exp_a))
|
|
||||||
assert exp(abi).ae(mpc(exp_abi_real, exp_abi_imag))
|
|
||||||
|
|
||||||
assert (aa**bb).ae(mpf(pow_a_b))
|
|
||||||
assert (aa**abi).ae(mpc(pow_a_abi_real, pow_a_abi_imag))
|
|
||||||
assert (abi**abi).ae(mpc(pow_abi_abi_real, pow_abi_abi_imag))
|
|
||||||
|
|
||||||
assert sin(a).ae(mpf(sin_a))
|
|
||||||
assert sin(a1000).ae(mpf(sin_1000a))
|
|
||||||
assert sin(abi).ae(mpc(sin_abi_real, sin_abi_imag))
|
|
||||||
|
|
||||||
assert cos(a).ae(mpf(cos_a))
|
|
||||||
assert cos(a1000).ae(mpf(cos_1000a))
|
|
||||||
|
|
||||||
assert tan(a).ae(mpf(tan_a))
|
|
||||||
assert tan(abi).ae(mpc(tan_abi_real, tan_abi_imag))
|
|
||||||
|
|
||||||
# check that complex cancellation is avoided so that both
|
|
||||||
# real and imaginary parts have high relative accuracy.
|
|
||||||
# abs_eps should be 0, but has to be set to 1e-205 to pass the
|
|
||||||
# 200-digit case, probably due to slight inaccuracy in the
|
|
||||||
# precomputed input
|
|
||||||
assert (tan(abi).real).ae(mpf(tan_abi_real), abs_eps=1e-205)
|
|
||||||
assert (tan(abi).imag).ae(mpf(tan_abi_imag), abs_eps=1e-205)
|
|
||||||
mp.dps = 460
|
|
||||||
assert str(log(3))[-20:] == '02166121184001409826'
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
# Since str(a) can differ in the last digit from rounded a, and I want
|
|
||||||
# to compare the last digits of big numbers with the results in Mathematica,
|
|
||||||
# I made this hack to get the last 20 digits of rounded a
|
|
||||||
|
|
||||||
def last_digits(a):
|
|
||||||
r = repr(a)
|
|
||||||
s = str(a)
|
|
||||||
#dps = mp.dps
|
|
||||||
#mp.dps += 3
|
|
||||||
m = 10
|
|
||||||
r = r.replace(s[:-m],'')
|
|
||||||
r = r.replace("mpf('",'').replace("')",'')
|
|
||||||
num0 = 0
|
|
||||||
for c in r:
|
|
||||||
if c == '0':
|
|
||||||
num0 += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
b = float(int(r))/10**(len(r) - m)
|
|
||||||
if b >= 10**m - 0.5:
|
|
||||||
raise NotImplementedError
|
|
||||||
n = int(round(b))
|
|
||||||
sn = str(n)
|
|
||||||
s = s[:-m] + '0'*num0 + sn
|
|
||||||
return s[-20:]
|
|
||||||
|
|
||||||
# values checked with Mathematica
|
|
||||||
def test_log_hp():
|
|
||||||
mp.dps = 2000
|
|
||||||
a = mpf(10)**15000/3
|
|
||||||
r = log(a)
|
|
||||||
res = last_digits(r)
|
|
||||||
# Mathematica N[Log[10^15000/3], 2000]
|
|
||||||
# ...7443804441768333470331
|
|
||||||
assert res == '44380444176833347033'
|
|
||||||
|
|
||||||
# see issue 105
|
|
||||||
r = log(mpf(3)/2)
|
|
||||||
# Mathematica N[Log[3/2], 2000]
|
|
||||||
# ...69653749808140753263288
|
|
||||||
res = last_digits(r)
|
|
||||||
assert res == '53749808140753263288'
|
|
||||||
|
|
||||||
mp.dps = 10000
|
|
||||||
r = log(2)
|
|
||||||
res = last_digits(r)
|
|
||||||
# Mathematica N[Log[2], 10000]
|
|
||||||
# ...695615913401856601359655561
|
|
||||||
assert res == '91340185660135965556'
|
|
||||||
r = log(mpf(10)**10/3)
|
|
||||||
res = last_digits(r)
|
|
||||||
# Mathematica N[Log[10^10/3], 10000]
|
|
||||||
# ...587087654020631943060007154
|
|
||||||
assert res == '54020631943060007154', res
|
|
||||||
r = log(mpf(10)**100/3)
|
|
||||||
res = last_digits(r)
|
|
||||||
# Mathematica N[Log[10^100/3], 10000]
|
|
||||||
# ,,,59246336539088351652334666
|
|
||||||
assert res == '36539088351652334666', res
|
|
||||||
mp.dps += 10
|
|
||||||
a = 1 - mpf(1)/10**10
|
|
||||||
mp.dps -= 10
|
|
||||||
r = log(a)
|
|
||||||
res = last_digits(r)
|
|
||||||
# ...3310334360482956137216724048322957404
|
|
||||||
# 372167240483229574038733026370
|
|
||||||
# Mathematica N[Log[1 - 10^-10]*10^10, 10000]
|
|
||||||
# ...60482956137216724048322957404
|
|
||||||
assert res == '37216724048322957404', res
|
|
||||||
mp.dps = 10000
|
|
||||||
mp.dps += 100
|
|
||||||
a = 1 + mpf(1)/10**100
|
|
||||||
mp.dps -= 100
|
|
||||||
|
|
||||||
r = log(a)
|
|
||||||
res = last_digits(+r)
|
|
||||||
# Mathematica N[Log[1 + 10^-100]*10^10, 10030]
|
|
||||||
# ...3994733877377412241546890854692521568292338268273 10^-91
|
|
||||||
assert res == '39947338773774122415', res
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_exp_hp():
|
|
||||||
mp.dps = 4000
|
|
||||||
r = exp(mpf(1)/10)
|
|
||||||
# IntegerPart[N[Exp[1/10] * 10^4000, 4000]]
|
|
||||||
# ...92167105162069688129
|
|
||||||
assert int(r * 10**mp.dps) % 10**20 == 92167105162069688129
|
|
||||||
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_pslq():
|
|
||||||
mp.dps = 15
|
|
||||||
assert pslq([3*pi+4*e/7, pi, e, log(2)]) == [7, -21, -4, 0]
|
|
||||||
assert pslq([4.9999999999999991, 1]) == [1, -5]
|
|
||||||
assert pslq([2,1]) == [1, -2]
|
|
||||||
|
|
||||||
def test_identify():
|
|
||||||
mp.dps = 20
|
|
||||||
assert identify(zeta(4), ['log(2)', 'pi**4']) == '((1/90)*pi**4)'
|
|
||||||
mp.dps = 15
|
|
||||||
assert identify(exp(5)) == 'exp(5)'
|
|
||||||
assert identify(exp(4)) == 'exp(4)'
|
|
||||||
assert identify(log(5)) == 'log(5)'
|
|
||||||
assert identify(exp(3*pi), ['pi']) == 'exp((3*pi))'
|
|
||||||
assert identify(3, full=True) == ['3', '3', '1/(1/3)', 'sqrt(9)',
|
|
||||||
'1/sqrt((1/9))', '(sqrt(12)/2)**2', '1/(sqrt(12)/6)**2']
|
|
||||||
assert identify(pi+1, {'a':+pi}) == '(1 + 1*a)'
|
|
||||||
|
|
@ -1,264 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
mpi_to_str = mp.mpi_to_str
|
|
||||||
mpi_from_str = mp.mpi_from_str
|
|
||||||
|
|
||||||
def test_interval_identity():
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpi(2) == mpi(2, 2)
|
|
||||||
assert mpi(2) != mpi(-2, 2)
|
|
||||||
assert not (mpi(2) != mpi(2, 2))
|
|
||||||
assert mpi(-1, 1) == mpi(-1, 1)
|
|
||||||
assert str(mpi('0.1')) == "[0.099999999999999991673, 0.10000000000000000555]"
|
|
||||||
assert repr(mpi('0.1')) == "mpi(mpf('0.099999999999999992'), mpf('0.10000000000000001'))"
|
|
||||||
u = mpi(-1, 3)
|
|
||||||
assert -1 in u
|
|
||||||
assert 2 in u
|
|
||||||
assert 3 in u
|
|
||||||
assert -1.1 not in u
|
|
||||||
assert 3.1 not in u
|
|
||||||
assert mpi(-1, 3) in u
|
|
||||||
assert mpi(0, 1) in u
|
|
||||||
assert mpi(-1.1, 2) not in u
|
|
||||||
assert mpi(2.5, 3.1) not in u
|
|
||||||
w = mpi(-inf, inf)
|
|
||||||
assert mpi(-5, 5) in w
|
|
||||||
assert mpi(2, inf) in w
|
|
||||||
assert mpi(0, 2) in mpi(0, 10)
|
|
||||||
assert not (3 in mpi(-inf, 0))
|
|
||||||
|
|
||||||
def test_interval_arithmetic():
|
|
||||||
mp.dps = 15
|
|
||||||
assert mpi(2) + mpi(3,4) == mpi(5,6)
|
|
||||||
assert mpi(1, 2)**2 == mpi(1, 4)
|
|
||||||
assert mpi(1) + mpi(0, 1e-50) == mpi(1, mpf('1.0000000000000002'))
|
|
||||||
x = 1 / (1 / mpi(3))
|
|
||||||
assert x.a < 3 < x.b
|
|
||||||
x = mpi(2) ** mpi(0.5)
|
|
||||||
mp.dps += 5
|
|
||||||
sq = sqrt(2)
|
|
||||||
mp.dps -= 5
|
|
||||||
assert x.a < sq < x.b
|
|
||||||
assert mpi(1) / mpi(1, inf)
|
|
||||||
assert mpi(2, 3) / inf == mpi(0, 0)
|
|
||||||
assert mpi(0) / inf == 0
|
|
||||||
assert mpi(0) / 0 == mpi(-inf, inf)
|
|
||||||
assert mpi(inf) / 0 == mpi(-inf, inf)
|
|
||||||
assert mpi(0) * inf == mpi(-inf, inf)
|
|
||||||
assert 1 / mpi(2, inf) == mpi(0, 0.5)
|
|
||||||
assert str((mpi(50, 50) * mpi(-10, -10)) / 3) == \
|
|
||||||
'[-166.66666666666668561, -166.66666666666665719]'
|
|
||||||
assert mpi(0, 4) ** 3 == mpi(0, 64)
|
|
||||||
assert mpi(2,4).mid == 3
|
|
||||||
mp.dps = 30
|
|
||||||
a = mpi(pi)
|
|
||||||
mp.dps = 15
|
|
||||||
b = +a
|
|
||||||
assert b.a < a.a
|
|
||||||
assert b.b > a.b
|
|
||||||
a = mpi(pi)
|
|
||||||
assert a == +a
|
|
||||||
assert abs(mpi(-1,2)) == mpi(0,2)
|
|
||||||
assert abs(mpi(0.5,2)) == mpi(0.5,2)
|
|
||||||
assert abs(mpi(-3,2)) == mpi(0,3)
|
|
||||||
assert abs(mpi(-3,-0.5)) == mpi(0.5,3)
|
|
||||||
assert mpi(0) * mpi(2,3) == mpi(0)
|
|
||||||
assert mpi(2,3) * mpi(0) == mpi(0)
|
|
||||||
assert mpi(1,3).delta == 2
|
|
||||||
assert mpi(1,2) - mpi(3,4) == mpi(-3,-1)
|
|
||||||
assert mpi(-inf,0) - mpi(0,inf) == mpi(-inf,0)
|
|
||||||
assert mpi(-inf,0) - mpi(-inf,inf) == mpi(-inf,inf)
|
|
||||||
assert mpi(0,inf) - mpi(-inf,1) == mpi(-1,inf)
|
|
||||||
|
|
||||||
def test_interval_mul():
|
|
||||||
assert mpi(-1, 0) * inf == mpi(-inf, 0)
|
|
||||||
assert mpi(-1, 0) * -inf == mpi(0, inf)
|
|
||||||
assert mpi(0, 1) * inf == mpi(0, inf)
|
|
||||||
assert mpi(0, 1) * mpi(0, inf) == mpi(0, inf)
|
|
||||||
assert mpi(-1, 1) * inf == mpi(-inf, inf)
|
|
||||||
assert mpi(-1, 1) * mpi(0, inf) == mpi(-inf, inf)
|
|
||||||
assert mpi(-1, 1) * mpi(-inf, inf) == mpi(-inf, inf)
|
|
||||||
assert mpi(-inf, 0) * mpi(0, 1) == mpi(-inf, 0)
|
|
||||||
assert mpi(-inf, 0) * mpi(0, 0) * mpi(-inf, 0)
|
|
||||||
assert mpi(-inf, 0) * mpi(-inf, inf) == mpi(-inf, inf)
|
|
||||||
assert mpi(-5,0)*mpi(-32,28) == mpi(-140,160)
|
|
||||||
assert mpi(2,3) * mpi(-1,2) == mpi(-3,6)
|
|
||||||
# Should be undefined?
|
|
||||||
assert mpi(inf, inf) * 0 == mpi(-inf, inf)
|
|
||||||
assert mpi(-inf, -inf) * 0 == mpi(-inf, inf)
|
|
||||||
assert mpi(0) * mpi(-inf,2) == mpi(-inf,inf)
|
|
||||||
assert mpi(0) * mpi(-2,inf) == mpi(-inf,inf)
|
|
||||||
assert mpi(-2,inf) * mpi(0) == mpi(-inf,inf)
|
|
||||||
assert mpi(-inf,2) * mpi(0) == mpi(-inf,inf)
|
|
||||||
|
|
||||||
def test_interval_pow():
|
|
||||||
assert mpi(3)**2 == mpi(9, 9)
|
|
||||||
assert mpi(-3)**2 == mpi(9, 9)
|
|
||||||
assert mpi(-3, 1)**2 == mpi(0, 9)
|
|
||||||
assert mpi(-3, -1)**2 == mpi(1, 9)
|
|
||||||
assert mpi(-3, -1)**3 == mpi(-27, -1)
|
|
||||||
assert mpi(-3, 1)**3 == mpi(-27, 1)
|
|
||||||
assert mpi(-2, 3)**2 == mpi(0, 9)
|
|
||||||
assert mpi(-3, 2)**2 == mpi(0, 9)
|
|
||||||
assert mpi(4) ** -1 == mpi(0.25, 0.25)
|
|
||||||
assert mpi(-4) ** -1 == mpi(-0.25, -0.25)
|
|
||||||
assert mpi(4) ** -2 == mpi(0.0625, 0.0625)
|
|
||||||
assert mpi(-4) ** -2 == mpi(0.0625, 0.0625)
|
|
||||||
assert mpi(0, 1) ** inf == mpi(0, 1)
|
|
||||||
assert mpi(0, 1) ** -inf == mpi(1, inf)
|
|
||||||
assert mpi(0, inf) ** inf == mpi(0, inf)
|
|
||||||
assert mpi(0, inf) ** -inf == mpi(0, inf)
|
|
||||||
assert mpi(1, inf) ** inf == mpi(1, inf)
|
|
||||||
assert mpi(1, inf) ** -inf == mpi(0, 1)
|
|
||||||
assert mpi(2, 3) ** 1 == mpi(2, 3)
|
|
||||||
assert mpi(2, 3) ** 0 == 1
|
|
||||||
assert mpi(1,3) ** mpi(2) == mpi(1,9)
|
|
||||||
|
|
||||||
def test_interval_sqrt():
|
|
||||||
assert mpi(4) ** 0.5 == mpi(2)
|
|
||||||
|
|
||||||
def test_interval_div():
|
|
||||||
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
|
|
||||||
assert mpi(0, 1) / mpi(0, 1) == mpi(0, inf)
|
|
||||||
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
|
|
||||||
assert mpi(inf, inf) / mpi(2, inf) == mpi(0, inf)
|
|
||||||
assert mpi(inf, inf) / mpi(2, 2) == mpi(inf, inf)
|
|
||||||
assert mpi(0, inf) / mpi(2, inf) == mpi(0, inf)
|
|
||||||
assert mpi(0, inf) / mpi(2, 2) == mpi(0, inf)
|
|
||||||
assert mpi(2, inf) / mpi(2, 2) == mpi(1, inf)
|
|
||||||
assert mpi(2, inf) / mpi(2, inf) == mpi(0, inf)
|
|
||||||
assert mpi(-4, 8) / mpi(1, inf) == mpi(-4, 8)
|
|
||||||
assert mpi(-4, 8) / mpi(0.5, inf) == mpi(-8, 16)
|
|
||||||
assert mpi(-inf, 8) / mpi(0.5, inf) == mpi(-inf, 16)
|
|
||||||
assert mpi(-inf, inf) / mpi(0.5, inf) == mpi(-inf, inf)
|
|
||||||
assert mpi(8, inf) / mpi(0.5, inf) == mpi(0, inf)
|
|
||||||
assert mpi(-8, inf) / mpi(0.5, inf) == mpi(-16, inf)
|
|
||||||
assert mpi(-4, 8) / mpi(inf, inf) == mpi(0, 0)
|
|
||||||
assert mpi(0, 8) / mpi(inf, inf) == mpi(0, 0)
|
|
||||||
assert mpi(0, 0) / mpi(inf, inf) == mpi(0, 0)
|
|
||||||
assert mpi(-inf, 0) / mpi(inf, inf) == mpi(-inf, 0)
|
|
||||||
assert mpi(-inf, 8) / mpi(inf, inf) == mpi(-inf, 0)
|
|
||||||
assert mpi(-inf, inf) / mpi(inf, inf) == mpi(-inf, inf)
|
|
||||||
assert mpi(-8, inf) / mpi(inf, inf) == mpi(0, inf)
|
|
||||||
assert mpi(0, inf) / mpi(inf, inf) == mpi(0, inf)
|
|
||||||
assert mpi(8, inf) / mpi(inf, inf) == mpi(0, inf)
|
|
||||||
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
|
|
||||||
assert mpi(-1, 2) / mpi(0, 1) == mpi(-inf, +inf)
|
|
||||||
assert mpi(0, 1) / mpi(0, 1) == mpi(0.0, +inf)
|
|
||||||
assert mpi(-1, 0) / mpi(0, 1) == mpi(-inf, 0.0)
|
|
||||||
assert mpi(-0.5, -0.25) / mpi(0, 1) == mpi(-inf, -0.25)
|
|
||||||
assert mpi(0.5, 1) / mpi(0, 1) == mpi(0.5, +inf)
|
|
||||||
assert mpi(0.5, 4) / mpi(0, 1) == mpi(0.5, +inf)
|
|
||||||
assert mpi(-1, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
|
|
||||||
assert mpi(-4, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
|
|
||||||
assert mpi(-1, 2) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(0, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(-1, 0) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(-0.5, -0.25) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(0.5, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(0.5, 4) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(-1, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(-4, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
|
|
||||||
assert mpi(-1, 2) / mpi(-1, 0) == mpi(-inf, +inf)
|
|
||||||
assert mpi(0, 1) / mpi(-1, 0) == mpi(-inf, 0.0)
|
|
||||||
assert mpi(-1, 0) / mpi(-1, 0) == mpi(0.0, +inf)
|
|
||||||
assert mpi(-0.5, -0.25) / mpi(-1, 0) == mpi(0.25, +inf)
|
|
||||||
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
|
|
||||||
assert mpi(0.5, 4) / mpi(-1, 0) == mpi(-inf, -0.5)
|
|
||||||
assert mpi(-1, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
|
|
||||||
assert mpi(-4, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
|
|
||||||
assert mpi(-1, 2) / mpi(0.5, 1) == mpi(-2.0, 4.0)
|
|
||||||
assert mpi(0, 1) / mpi(0.5, 1) == mpi(0.0, 2.0)
|
|
||||||
assert mpi(-1, 0) / mpi(0.5, 1) == mpi(-2.0, 0.0)
|
|
||||||
assert mpi(-0.5, -0.25) / mpi(0.5, 1) == mpi(-1.0, -0.25)
|
|
||||||
assert mpi(0.5, 1) / mpi(0.5, 1) == mpi(0.5, 2.0)
|
|
||||||
assert mpi(0.5, 4) / mpi(0.5, 1) == mpi(0.5, 8.0)
|
|
||||||
assert mpi(-1, -0.5) / mpi(0.5, 1) == mpi(-2.0, -0.5)
|
|
||||||
assert mpi(-4, -0.5) / mpi(0.5, 1) == mpi(-8.0, -0.5)
|
|
||||||
assert mpi(-1, 2) / mpi(-2, -0.5) == mpi(-4.0, 2.0)
|
|
||||||
assert mpi(0, 1) / mpi(-2, -0.5) == mpi(-2.0, 0.0)
|
|
||||||
assert mpi(-1, 0) / mpi(-2, -0.5) == mpi(0.0, 2.0)
|
|
||||||
assert mpi(-0.5, -0.25) / mpi(-2, -0.5) == mpi(0.125, 1.0)
|
|
||||||
assert mpi(0.5, 1) / mpi(-2, -0.5) == mpi(-2.0, -0.25)
|
|
||||||
assert mpi(0.5, 4) / mpi(-2, -0.5) == mpi(-8.0, -0.25)
|
|
||||||
assert mpi(-1, -0.5) / mpi(-2, -0.5) == mpi(0.25, 2.0)
|
|
||||||
assert mpi(-4, -0.5) / mpi(-2, -0.5) == mpi(0.25, 8.0)
|
|
||||||
# Should be undefined?
|
|
||||||
assert mpi(0, 0) / mpi(0, 0) == mpi(-inf, inf)
|
|
||||||
assert mpi(0, 0) / mpi(0, 1) == mpi(-inf, inf)
|
|
||||||
|
|
||||||
def test_interval_cos_sin():
|
|
||||||
mp.dps = 15
|
|
||||||
# Around 0
|
|
||||||
assert cos(mpi(0)) == 1
|
|
||||||
assert sin(mpi(0)) == 0
|
|
||||||
assert cos(mpi(0,1)) == mpi(0.54030230586813965399, 1.0)
|
|
||||||
assert sin(mpi(0,1)) == mpi(0, 0.8414709848078966159)
|
|
||||||
assert cos(mpi(1,2)) == mpi(-0.4161468365471424069, 0.54030230586813976501)
|
|
||||||
assert sin(mpi(1,2)) == mpi(0.84147098480789650488, 1.0)
|
|
||||||
assert sin(mpi(1,2.5)) == mpi(0.59847214410395643824, 1.0)
|
|
||||||
assert cos(mpi(-1, 1)) == mpi(0.54030230586813965399, 1.0)
|
|
||||||
assert cos(mpi(-1, 0.5)) == mpi(0.54030230586813965399, 1.0)
|
|
||||||
assert cos(mpi(-1, 1.5)) == mpi(0.070737201667702906405, 1.0)
|
|
||||||
assert sin(mpi(-1,1)) == mpi(-0.8414709848078966159, 0.8414709848078966159)
|
|
||||||
assert sin(mpi(-1,0.5)) == mpi(-0.8414709848078966159, 0.47942553860420300538)
|
|
||||||
assert sin(mpi(-1,1e-100)) == mpi(-0.8414709848078966159, 1.00000000000000002e-100)
|
|
||||||
assert sin(mpi(-2e-100,1e-100)) == mpi(-2.00000000000000004e-100, 1.00000000000000002e-100)
|
|
||||||
# Same interval
|
|
||||||
assert cos(mpi(2, 2.5)) == mpi(-0.80114361554693380718, -0.41614683654714235139)
|
|
||||||
assert cos(mpi(3.5, 4)) == mpi(-0.93645668729079634129, -0.65364362086361182946)
|
|
||||||
assert cos(mpi(5, 5.5)) == mpi(0.28366218546322624627, 0.70866977429126010168)
|
|
||||||
assert sin(mpi(2, 2.5)) == mpi(0.59847214410395654927, 0.90929742682568170942)
|
|
||||||
assert sin(mpi(3.5, 4)) == mpi(-0.75680249530792831347, -0.35078322768961983646)
|
|
||||||
assert sin(mpi(5, 5.5)) == mpi(-0.95892427466313856499, -0.70554032557039181306)
|
|
||||||
# Higher roots
|
|
||||||
mp.dps = 55
|
|
||||||
w = 4*10**50 + mpf(0.5)
|
|
||||||
for p in [15, 40, 80]:
|
|
||||||
mp.dps = p
|
|
||||||
assert 0 in sin(4*mpi(pi))
|
|
||||||
assert 0 in sin(4*10**50*mpi(pi))
|
|
||||||
assert 0 in cos((4+0.5)*mpi(pi))
|
|
||||||
assert 0 in cos(w*mpi(pi))
|
|
||||||
assert 1 in cos(4*mpi(pi))
|
|
||||||
assert 1 in cos(4*10**50*mpi(pi))
|
|
||||||
mp.dps = 15
|
|
||||||
assert cos(mpi(2,inf)) == mpi(-1,1)
|
|
||||||
assert sin(mpi(2,inf)) == mpi(-1,1)
|
|
||||||
assert cos(mpi(-inf,2)) == mpi(-1,1)
|
|
||||||
assert sin(mpi(-inf,2)) == mpi(-1,1)
|
|
||||||
u = tan(mpi(0.5,1))
|
|
||||||
assert u.a.ae(tan(0.5))
|
|
||||||
assert u.b.ae(tan(1))
|
|
||||||
v = cot(mpi(0.5,1))
|
|
||||||
assert v.a.ae(cot(1))
|
|
||||||
assert v.b.ae(cot(0.5))
|
|
||||||
|
|
||||||
def test_mpi_to_str():
|
|
||||||
mp.dps = 30
|
|
||||||
x = mpi(1, 2)
|
|
||||||
# FIXME: error_dps should not be necessary
|
|
||||||
assert mpi_to_str(x, mode='plusminus', error_dps=6) == '1.5 +- 0.5'
|
|
||||||
assert mpi_to_str(x, mode='plusminus', use_spaces=False, error_dps=6
|
|
||||||
) == '1.5+-0.5'
|
|
||||||
assert mpi_to_str(x, mode='percent') == '1.5 (33.33%)'
|
|
||||||
assert mpi_to_str(x, mode='brackets', use_spaces=False) == '[1.0,2.0]'
|
|
||||||
assert mpi_to_str(x, mode='brackets' , brackets=('<', '>')) == '<1.0, 2.0>'
|
|
||||||
x = mpi('5.2582327113062393041', '5.2582327113062749951')
|
|
||||||
assert (mpi_to_str(x, mode='diff') ==
|
|
||||||
'5.2582327113062[393041, 749951]')
|
|
||||||
assert (mpi_to_str(cos(mpi(1)), mode='diff', use_spaces=False) ==
|
|
||||||
'0.54030230586813971740093660744[2955,3053]')
|
|
||||||
assert (mpi_to_str(mpi('1e123', '1e129'), mode='diff') ==
|
|
||||||
'[1.0e+123, 1.0e+129]')
|
|
||||||
assert (mpi_to_str(exp(mpi('5000.1')), mode='diff') ==
|
|
||||||
'3.2797365856787867069110487[0926, 1191]e+2171')
|
|
||||||
|
|
||||||
def test_mpi_from_str():
|
|
||||||
assert mpi_from_str('1.5 +- 0.5') == mpi(mpf('1.0'), mpf('2.0'))
|
|
||||||
assert (mpi_from_str('1.5 (33.33333333333333333333333333333%)') ==
|
|
||||||
mpi(mpf(1), mpf(2)))
|
|
||||||
assert mpi_from_str('[1, 2]') == mpi(1, 2)
|
|
||||||
assert mpi_from_str('1[2, 3]') == mpi(12, 13)
|
|
||||||
assert mpi_from_str('1.[23,46]e-8') == mpi('1.23e-8', '1.46e-8')
|
|
||||||
assert mpi_from_str('12[3.4,5.9]e4') == mpi('123.4e+4', '125.9e4')
|
|
||||||
|
|
@ -1,243 +0,0 @@
|
||||||
# TODO: don't use round
|
|
||||||
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
# XXX: these shouldn't be visible(?)
|
|
||||||
LU_decomp = mp.LU_decomp
|
|
||||||
L_solve = mp.L_solve
|
|
||||||
U_solve = mp.U_solve
|
|
||||||
householder = mp.householder
|
|
||||||
improve_solution = mp.improve_solution
|
|
||||||
|
|
||||||
A1 = matrix([[3, 1, 6],
|
|
||||||
[2, 1, 3],
|
|
||||||
[1, 1, 1]])
|
|
||||||
b1 = [2, 7, 4]
|
|
||||||
|
|
||||||
A2 = matrix([[ 2, -1, -1, 2],
|
|
||||||
[ 6, -2, 3, -1],
|
|
||||||
[-4, 2, 3, -2],
|
|
||||||
[ 2, 0, 4, -3]])
|
|
||||||
b2 = [3, -3, -2, -1]
|
|
||||||
|
|
||||||
A3 = matrix([[ 1, 0, -1, -1, 0],
|
|
||||||
[ 0, 1, 1, 0, -1],
|
|
||||||
[ 4, -5, 2, 0, 0],
|
|
||||||
[ 0, 0, -2, 9,-12],
|
|
||||||
[ 0, 5, 0, 0, 12]])
|
|
||||||
b3 = [0, 0, 0, 0, 50]
|
|
||||||
|
|
||||||
A4 = matrix([[10.235, -4.56, 0., -0.035, 5.67],
|
|
||||||
[-2.463, 1.27, 3.97, -8.63, 1.08],
|
|
||||||
[-6.58, 0.86, -0.257, 9.32, -43.6 ],
|
|
||||||
[ 9.83, 7.39, -17.25, 0.036, 24.86],
|
|
||||||
[-9.31, 34.9, 78.56, 1.07, 65.8 ]])
|
|
||||||
b4 = [8.95, 20.54, 7.42, 5.60, 58.43]
|
|
||||||
|
|
||||||
A5 = matrix([[ 1, 2, -4],
|
|
||||||
[-2, -3, 5],
|
|
||||||
[ 3, 5, -8]])
|
|
||||||
|
|
||||||
A6 = matrix([[ 1.377360, 2.481400, 5.359190],
|
|
||||||
[ 2.679280, -1.229560, 25.560210],
|
|
||||||
[-1.225280+1.e6, 9.910180, -35.049900-1.e6]])
|
|
||||||
b6 = [23.500000, -15.760000, 2.340000]
|
|
||||||
|
|
||||||
A7 = matrix([[1, -0.5],
|
|
||||||
[2, 1],
|
|
||||||
[-2, 6]])
|
|
||||||
b7 = [3, 2, -4]
|
|
||||||
|
|
||||||
A8 = matrix([[1, 2, 3],
|
|
||||||
[-1, 0, 1],
|
|
||||||
[-1, -2, -1],
|
|
||||||
[1, 0, -1]])
|
|
||||||
b8 = [1, 2, 3, 4]
|
|
||||||
|
|
||||||
A9 = matrix([[ 4, 2, -2],
|
|
||||||
[ 2, 5, -4],
|
|
||||||
[-2, -4, 5.5]])
|
|
||||||
b9 = [10, 16, -15.5]
|
|
||||||
|
|
||||||
A10 = matrix([[1.0 + 1.0j, 2.0, 2.0],
|
|
||||||
[4.0, 5.0, 6.0],
|
|
||||||
[7.0, 8.0, 9.0]])
|
|
||||||
b10 = [1.0, 1.0 + 1.0j, 1.0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_LU_decomp():
|
|
||||||
A = A3.copy()
|
|
||||||
b = b3
|
|
||||||
A, p = LU_decomp(A)
|
|
||||||
y = L_solve(A, b, p)
|
|
||||||
x = U_solve(A, y)
|
|
||||||
assert p == [2, 1, 2, 3]
|
|
||||||
assert [round(i, 14) for i in x] == [3.78953107960742, 2.9989094874591098,
|
|
||||||
-0.081788440567070006, 3.8713195201744801, 2.9171210468920399]
|
|
||||||
A = A4.copy()
|
|
||||||
b = b4
|
|
||||||
A, p = LU_decomp(A)
|
|
||||||
y = L_solve(A, b, p)
|
|
||||||
x = U_solve(A, y)
|
|
||||||
assert p == [0, 3, 4, 3]
|
|
||||||
assert [round(i, 14) for i in x] == [2.6383625899619201, 2.6643834462368399,
|
|
||||||
0.79208015947958998, -2.5088376454101899, -1.0567657691375001]
|
|
||||||
A = randmatrix(3)
|
|
||||||
bak = A.copy()
|
|
||||||
LU_decomp(A, overwrite=1)
|
|
||||||
assert A != bak
|
|
||||||
|
|
||||||
def test_inverse():
|
|
||||||
for A in [A1, A2, A5]:
|
|
||||||
inv = inverse(A)
|
|
||||||
assert mnorm(A*inv - eye(A.rows), 1) < 1.e-14
|
|
||||||
|
|
||||||
def test_householder():
|
|
||||||
mp.dps = 15
|
|
||||||
A, b = A8, b8
|
|
||||||
H, p, x, r = householder(extend(A, b))
|
|
||||||
assert H == matrix(
|
|
||||||
[[mpf('3.0'), mpf('-2.0'), mpf('-1.0'), 0],
|
|
||||||
[-1.0,mpf('3.333333333333333'),mpf('-2.9999999999999991'),mpf('2.0')],
|
|
||||||
[-1.0, mpf('-0.66666666666666674'),mpf('2.8142135623730948'),
|
|
||||||
mpf('-2.8284271247461898')],
|
|
||||||
[1.0, mpf('-1.3333333333333333'),mpf('-0.20000000000000018'),
|
|
||||||
mpf('4.2426406871192857')]])
|
|
||||||
assert p == [-2, -2, mpf('-1.4142135623730949')]
|
|
||||||
assert round(norm(r, 2), 10) == 4.2426406870999998
|
|
||||||
|
|
||||||
y = [102.102, 58.344, 36.463, 24.310, 17.017, 12.376, 9.282, 7.140, 5.610,
|
|
||||||
4.488, 3.6465, 3.003]
|
|
||||||
|
|
||||||
def coeff(n):
|
|
||||||
# similiar to Hilbert matrix
|
|
||||||
A = []
|
|
||||||
for i in xrange(1, 13):
|
|
||||||
A.append([1. / (i + j - 1) for j in xrange(1, n + 1)])
|
|
||||||
return matrix(A)
|
|
||||||
|
|
||||||
residuals = []
|
|
||||||
refres = []
|
|
||||||
for n in xrange(2, 7):
|
|
||||||
A = coeff(n)
|
|
||||||
H, p, x, r = householder(extend(A, y))
|
|
||||||
x = matrix(x)
|
|
||||||
y = matrix(y)
|
|
||||||
residuals.append(norm(r, 2))
|
|
||||||
refres.append(norm(residual(A, x, y), 2))
|
|
||||||
assert [round(res, 10) for res in residuals] == [15.1733888877,
|
|
||||||
0.82378073210000002, 0.302645887, 0.0260109244,
|
|
||||||
0.00058653999999999998]
|
|
||||||
assert norm(matrix(residuals) - matrix(refres), inf) < 1.e-13
|
|
||||||
|
|
||||||
def test_factorization():
|
|
||||||
A = randmatrix(5)
|
|
||||||
P, L, U = lu(A)
|
|
||||||
assert mnorm(P*A - L*U, 1) < 1.e-15
|
|
||||||
|
|
||||||
def test_solve():
|
|
||||||
assert norm(residual(A6, lu_solve(A6, b6), b6), inf) < 1.e-10
|
|
||||||
assert norm(residual(A7, lu_solve(A7, b7), b7), inf) < 1.5
|
|
||||||
assert norm(residual(A8, lu_solve(A8, b8), b8), inf) <= 3 + 1.e-10
|
|
||||||
assert norm(residual(A6, qr_solve(A6, b6)[0], b6), inf) < 1.e-10
|
|
||||||
assert norm(residual(A7, qr_solve(A7, b7)[0], b7), inf) < 1.5
|
|
||||||
assert norm(residual(A8, qr_solve(A8, b8)[0], b8), 2) <= 4.3
|
|
||||||
assert norm(residual(A10, lu_solve(A10, b10), b10), 2) < 1.e-10
|
|
||||||
assert norm(residual(A10, qr_solve(A10, b10)[0], b10), 2) < 1.e-10
|
|
||||||
|
|
||||||
def test_solve_overdet_complex():
|
|
||||||
A = matrix([[1, 2j], [3, 4j], [5, 6]])
|
|
||||||
b = matrix([1 + j, 2, -j])
|
|
||||||
assert norm(residual(A, lu_solve(A, b), b)) < 1.0208
|
|
||||||
|
|
||||||
def test_singular():
|
|
||||||
mp.dps = 15
|
|
||||||
A = [[5.6, 1.2], [7./15, .1]]
|
|
||||||
B = repr(zeros(2))
|
|
||||||
b = [1, 2]
|
|
||||||
def _assert_ZeroDivisionError(statement):
|
|
||||||
try:
|
|
||||||
eval(statement)
|
|
||||||
assert False
|
|
||||||
except (ZeroDivisionError, ValueError):
|
|
||||||
pass
|
|
||||||
for i in ['lu_solve(%s, %s)' % (A, b), 'lu_solve(%s, %s)' % (B, b),
|
|
||||||
'qr_solve(%s, %s)' % (A, b), 'qr_solve(%s, %s)' % (B, b)]:
|
|
||||||
_assert_ZeroDivisionError(i)
|
|
||||||
|
|
||||||
def test_cholesky():
|
|
||||||
assert fp.cholesky(fp.matrix(A9)) == fp.matrix([[2, 0, 0], [1, 2, 0], [-1, -3/2, 3/2]])
|
|
||||||
x = fp.cholesky_solve(A9, b9)
|
|
||||||
assert fp.norm(fp.residual(A9, x, b9), fp.inf) == 0
|
|
||||||
|
|
||||||
def test_det():
|
|
||||||
assert det(A1) == 1
|
|
||||||
assert round(det(A2), 14) == 8
|
|
||||||
assert round(det(A3)) == 1834
|
|
||||||
assert round(det(A4)) == 4443376
|
|
||||||
assert det(A5) == 1
|
|
||||||
assert round(det(A6)) == 78356463
|
|
||||||
assert det(zeros(3)) == 0
|
|
||||||
|
|
||||||
def test_cond():
|
|
||||||
mp.dps = 15
|
|
||||||
A = matrix([[1.2969, 0.8648], [0.2161, 0.1441]])
|
|
||||||
assert cond(A, lambda x: mnorm(x,1)) == mpf('327065209.73817754')
|
|
||||||
assert cond(A, lambda x: mnorm(x,inf)) == mpf('327065209.73817754')
|
|
||||||
assert cond(A, lambda x: mnorm(x,'F')) == mpf('249729266.80008656')
|
|
||||||
|
|
||||||
@extradps(50)
|
|
||||||
def test_precision():
|
|
||||||
A = randmatrix(10, 10)
|
|
||||||
assert mnorm(inverse(inverse(A)) - A, 1) < 1.e-45
|
|
||||||
|
|
||||||
def test_interval_matrix():
|
|
||||||
a = matrix([['0.1','0.3','1.0'],['7.1','5.5','4.8'],['3.2','4.4','5.6']],
|
|
||||||
force_type=mpi)
|
|
||||||
b = matrix(['4','0.6','0.5'], force_type=mpi)
|
|
||||||
c = lu_solve(a, b)
|
|
||||||
assert c[0].delta < 1e-13
|
|
||||||
assert c[1].delta < 1e-13
|
|
||||||
assert c[2].delta < 1e-13
|
|
||||||
assert 5.25823271130625686059275 in c[0]
|
|
||||||
assert -13.155049396267837541163 in c[1]
|
|
||||||
assert 7.42069154774972557628979 in c[2]
|
|
||||||
|
|
||||||
def test_LU_cache():
|
|
||||||
A = randmatrix(3)
|
|
||||||
LU = LU_decomp(A)
|
|
||||||
assert A._LU == LU_decomp(A)
|
|
||||||
A[0,0] = -1000
|
|
||||||
assert A._LU is None
|
|
||||||
|
|
||||||
def test_improve_solution():
|
|
||||||
A = randmatrix(5, min=1e-20, max=1e20)
|
|
||||||
b = randmatrix(5, 1, min=-1000, max=1000)
|
|
||||||
x1 = lu_solve(A, b) + randmatrix(5, 1, min=-1e-5, max=1.e-5)
|
|
||||||
x2 = improve_solution(A, x1, b)
|
|
||||||
assert norm(residual(A, x2, b), 2) < norm(residual(A, x1, b), 2)
|
|
||||||
|
|
||||||
def test_exp_pade():
|
|
||||||
for i in range(3):
|
|
||||||
dps = 15
|
|
||||||
extra = 5
|
|
||||||
mp.dps = dps + extra
|
|
||||||
dm = 0
|
|
||||||
while not dm:
|
|
||||||
m = randmatrix(3)
|
|
||||||
dm = det(m)
|
|
||||||
m = m/dm
|
|
||||||
a = diag([1,2,3])
|
|
||||||
a1 = m**-1 * a * m
|
|
||||||
mp.dps = dps
|
|
||||||
e1 = expm(a1, method='pade')
|
|
||||||
mp.dps = dps + extra
|
|
||||||
e2 = m * a1 * m**-1
|
|
||||||
d = e2 - a
|
|
||||||
#print d
|
|
||||||
mp.dps = dps
|
|
||||||
assert norm(d, inf).ae(0)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
|
|
@ -1,144 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_matrix_basic():
|
|
||||||
A1 = matrix(3)
|
|
||||||
for i in xrange(3):
|
|
||||||
A1[i,i] = 1
|
|
||||||
assert A1 == eye(3)
|
|
||||||
assert A1 == matrix(A1)
|
|
||||||
A2 = matrix(3, 2)
|
|
||||||
assert not A2._matrix__data
|
|
||||||
A3 = matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
assert list(A3) == range(1, 10)
|
|
||||||
A3[1,1] = 0
|
|
||||||
assert not (1, 1) in A3._matrix__data
|
|
||||||
A4 = matrix([[1, 2, 3], [4, 5, 6]])
|
|
||||||
A5 = matrix([[6, -1], [3, 2], [0, -3]])
|
|
||||||
assert A4 * A5 == matrix([[12, -6], [39, -12]])
|
|
||||||
assert A1 * A3 == A3 * A1 == A3
|
|
||||||
try:
|
|
||||||
A2 * A2
|
|
||||||
assert False
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
l = [[10, 20, 30], [40, 0, 60], [70, 80, 90]]
|
|
||||||
A6 = matrix(l)
|
|
||||||
assert A6.tolist() == l
|
|
||||||
assert A6 == eval(repr(A6))
|
|
||||||
A6 = matrix(A6, force_type=float)
|
|
||||||
assert A6 == eval(repr(A6))
|
|
||||||
assert A6*1j == eval(repr(A6*1j))
|
|
||||||
assert A3 * 10 == 10 * A3 == A6
|
|
||||||
assert A2.rows == 3
|
|
||||||
assert A2.cols == 2
|
|
||||||
A3.rows = 2
|
|
||||||
A3.cols = 2
|
|
||||||
assert len(A3._matrix__data) == 3
|
|
||||||
assert A4 + A4 == 2*A4
|
|
||||||
try:
|
|
||||||
A4 + A2
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
assert sum(A1 - A1) == 0
|
|
||||||
A7 = matrix([[1, 2], [3, 4], [5, 6], [7, 8]])
|
|
||||||
x = matrix([10, -10])
|
|
||||||
assert A7*x == matrix([-10, -10, -10, -10])
|
|
||||||
A8 = ones(5)
|
|
||||||
assert sum((A8 + 1) - (2 - zeros(5))) == 0
|
|
||||||
assert (1 + ones(4)) / 2 - 1 == zeros(4)
|
|
||||||
assert eye(3)**10 == eye(3)
|
|
||||||
try:
|
|
||||||
A7**2
|
|
||||||
assert False
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
A9 = randmatrix(3)
|
|
||||||
A10 = matrix(A9)
|
|
||||||
A9[0,0] = -100
|
|
||||||
assert A9 != A10
|
|
||||||
A11 = matrix(randmatrix(2, 3), force_type=mpi)
|
|
||||||
for a in A11:
|
|
||||||
assert isinstance(a, mpi)
|
|
||||||
assert nstr(A9)
|
|
||||||
|
|
||||||
def test_matrix_power():
|
|
||||||
A = matrix([[1, 2], [3, 4]])
|
|
||||||
assert A**2 == A*A
|
|
||||||
assert A**3 == A*A*A
|
|
||||||
assert A**-1 == inverse(A)
|
|
||||||
assert A**-2 == inverse(A*A)
|
|
||||||
|
|
||||||
def test_matrix_transform():
|
|
||||||
A = matrix([[1, 2], [3, 4], [5, 6]])
|
|
||||||
assert A.T == A.transpose() == matrix([[1, 3, 5], [2, 4, 6]])
|
|
||||||
swap_row(A, 1, 2)
|
|
||||||
assert A == matrix([[1, 2], [5, 6], [3, 4]])
|
|
||||||
l = [1, 2]
|
|
||||||
swap_row(l, 0, 1)
|
|
||||||
assert l == [2, 1]
|
|
||||||
assert extend(eye(3), [1,2,3]) == matrix([[1,0,0,1],[0,1,0,2],[0,0,1,3]])
|
|
||||||
|
|
||||||
def test_matrix_conjugate():
|
|
||||||
A = matrix([[1 + j, 0], [2, j]])
|
|
||||||
assert A.conjugate() == matrix([[mpc(1, -1), 0], [2, mpc(0, -1)]])
|
|
||||||
assert A.transpose_conj() == A.H == matrix([[mpc(1, -1), 2],
|
|
||||||
[0, mpc(0, -1)]])
|
|
||||||
|
|
||||||
def test_matrix_creation():
|
|
||||||
assert diag([1, 2, 3]) == matrix([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
|
|
||||||
A1 = ones(2, 3)
|
|
||||||
assert A1.rows == 2 and A1.cols == 3
|
|
||||||
for a in A1:
|
|
||||||
assert a == 1
|
|
||||||
A2 = zeros(3, 2)
|
|
||||||
assert A2.rows == 3 and A2.cols == 2
|
|
||||||
for a in A2:
|
|
||||||
assert a == 0
|
|
||||||
assert randmatrix(10) != randmatrix(10)
|
|
||||||
one = mpf(1)
|
|
||||||
assert hilbert(3) == matrix([[one, one/2, one/3],
|
|
||||||
[one/2, one/3, one/4],
|
|
||||||
[one/3, one/4, one/5]])
|
|
||||||
|
|
||||||
def test_norms():
|
|
||||||
# matrix norms
|
|
||||||
A = matrix([[1, -2], [-3, -1], [2, 1]])
|
|
||||||
assert mnorm(A,1) == 6
|
|
||||||
assert mnorm(A,inf) == 4
|
|
||||||
assert mnorm(A,'F') == sqrt(20)
|
|
||||||
# vector norms
|
|
||||||
assert norm(-3) == 3
|
|
||||||
x = [1, -2, 7, -12]
|
|
||||||
assert norm(x, 1) == 22
|
|
||||||
assert round(norm(x, 2), 10) == 14.0712472795
|
|
||||||
assert round(norm(x, 10), 10) == 12.0054633727
|
|
||||||
assert norm(x, inf) == 12
|
|
||||||
|
|
||||||
def test_vector():
|
|
||||||
x = matrix([0, 1, 2, 3, 4])
|
|
||||||
assert x == matrix([[0], [1], [2], [3], [4]])
|
|
||||||
assert x[3] == 3
|
|
||||||
assert len(x._matrix__data) == 4
|
|
||||||
assert list(x) == range(5)
|
|
||||||
x[0] = -10
|
|
||||||
x[4] = 0
|
|
||||||
assert x[0] == -10
|
|
||||||
assert len(x) == len(x.T) == 5
|
|
||||||
assert x.T*x == matrix([[114]])
|
|
||||||
|
|
||||||
def test_matrix_copy():
|
|
||||||
A = ones(6)
|
|
||||||
B = A.copy()
|
|
||||||
assert A == B
|
|
||||||
B[0,0] = 0
|
|
||||||
assert A != B
|
|
||||||
|
|
||||||
def test_matrix_numpy():
|
|
||||||
try:
|
|
||||||
import numpy
|
|
||||||
except ImportError:
|
|
||||||
return
|
|
||||||
l = [[1, 2], [3, 4], [5, 6]]
|
|
||||||
a = numpy.matrix(l)
|
|
||||||
assert matrix(l) == matrix(a)
|
|
||||||
|
|
||||||
|
|
@ -1,98 +0,0 @@
|
||||||
from mpmath.libmp import *
|
|
||||||
from mpmath import *
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------
|
|
||||||
# Low-level tests
|
|
||||||
#
|
|
||||||
|
|
||||||
# Advanced rounding test
|
|
||||||
def test_add_rounding():
|
|
||||||
mp.dps = 15
|
|
||||||
a = from_float(1e-50)
|
|
||||||
assert mpf_sub(mpf_add(fone, a, 53, round_up), fone, 53, round_up) == from_float(2.2204460492503131e-16)
|
|
||||||
assert mpf_sub(fone, a, 53, round_up) == fone
|
|
||||||
assert mpf_sub(fone, mpf_sub(fone, a, 53, round_down), 53, round_down) == from_float(1.1102230246251565e-16)
|
|
||||||
assert mpf_add(fone, a, 53, round_down) == fone
|
|
||||||
|
|
||||||
def test_almost_equal():
|
|
||||||
assert mpf(1.2).ae(mpf(1.20000001), 1e-7)
|
|
||||||
assert not mpf(1.2).ae(mpf(1.20000001), 1e-9)
|
|
||||||
assert not mpf(-0.7818314824680298).ae(mpf(-0.774695868667929))
|
|
||||||
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------------
|
|
||||||
# Test basic arithmetic
|
|
||||||
#
|
|
||||||
|
|
||||||
# Test that integer arithmetic is exact
|
|
||||||
def test_aintegers():
|
|
||||||
# XXX: re-fix this so that all operations are tested with all rounding modes
|
|
||||||
random.seed(0)
|
|
||||||
for prec in [6, 10, 25, 40, 100, 250, 725]:
|
|
||||||
for rounding in ['d', 'u', 'f', 'c', 'n']:
|
|
||||||
mp.dps = prec
|
|
||||||
M = 10**(prec-2)
|
|
||||||
M2 = 10**(prec//2-2)
|
|
||||||
for i in range(10):
|
|
||||||
a = random.randint(-M, M)
|
|
||||||
b = random.randint(-M, M)
|
|
||||||
assert mpf(a, rounding=rounding) == a
|
|
||||||
assert int(mpf(a, rounding=rounding)) == a
|
|
||||||
assert int(mpf(str(a), rounding=rounding)) == a
|
|
||||||
assert mpf(a) + mpf(b) == a + b
|
|
||||||
assert mpf(a) - mpf(b) == a - b
|
|
||||||
assert -mpf(a) == -a
|
|
||||||
a = random.randint(-M2, M2)
|
|
||||||
b = random.randint(-M2, M2)
|
|
||||||
assert mpf(a) * mpf(b) == a*b
|
|
||||||
assert mpf_mul(from_int(a), from_int(b), mp.prec, rounding) == from_int(a*b)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_odd_int_bug():
|
|
||||||
assert to_int(from_int(3), round_nearest) == 3
|
|
||||||
|
|
||||||
def test_str_1000_digits():
|
|
||||||
mp.dps = 1001
|
|
||||||
# last digit may be wrong
|
|
||||||
assert str(mpf(2)**0.5)[-10:-1] == '9518488472'[:9]
|
|
||||||
assert str(pi)[-10:-1] == '2164201989'[:9]
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_str_10000_digits():
|
|
||||||
mp.dps = 10001
|
|
||||||
# last digit may be wrong
|
|
||||||
assert str(mpf(2)**0.5)[-10:-1] == '5873258351'[:9]
|
|
||||||
assert str(pi)[-10:-1] == '5256375678'[:9]
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_monitor():
|
|
||||||
f = lambda x: x**2
|
|
||||||
a = []
|
|
||||||
b = []
|
|
||||||
g = monitor(f, a.append, b.append)
|
|
||||||
assert g(3) == 9
|
|
||||||
assert g(4) == 16
|
|
||||||
assert a[0] == ((3,), {})
|
|
||||||
assert b[0] == 9
|
|
||||||
|
|
||||||
def test_nint_distance():
|
|
||||||
nint_distance(mpf(-3)) == (-3, -inf)
|
|
||||||
nint_distance(mpc(-3)) == (-3, -inf)
|
|
||||||
nint_distance(mpf(-3.1)) == (-3, -3)
|
|
||||||
nint_distance(mpf(-3.01)) == (-3, -6)
|
|
||||||
nint_distance(mpf(-3.001)) == (-3, -9)
|
|
||||||
nint_distance(mpf(-3.0001)) == (-3, -13)
|
|
||||||
nint_distance(mpf(-2.9)) == (-3, -3)
|
|
||||||
nint_distance(mpf(-2.99)) == (-3, -6)
|
|
||||||
nint_distance(mpf(-2.999)) == (-3, -9)
|
|
||||||
nint_distance(mpf(-2.9999)) == (-3, -13)
|
|
||||||
nint_distance(mpc(-3+0.1j)) == (-3, -3)
|
|
||||||
nint_distance(mpc(-3+0.01j)) == (-3, -6)
|
|
||||||
nint_distance(mpc(-3.1+0.1j)) == (-3, -3)
|
|
||||||
nint_distance(mpc(-3.01+0.01j)) == (-3, -6)
|
|
||||||
nint_distance(mpc(-3.001+0.001j)) == (-3, -9)
|
|
||||||
nint_distance(mpf(0)) == (0, -inf)
|
|
||||||
nint_distance(mpf(0.01)) == (0, -6)
|
|
||||||
nint_distance(mpf('1e-100')) == (0, -332)
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
#from mpmath.calculus import ODE_step_euler, ODE_step_rk4, odeint, arange
|
|
||||||
from mpmath import odefun, cos, sin, mpf, sinc, mp
|
|
||||||
|
|
||||||
'''
|
|
||||||
solvers = [ODE_step_euler, ODE_step_rk4]
|
|
||||||
|
|
||||||
def test_ode1():
|
|
||||||
"""
|
|
||||||
Let's solve:
|
|
||||||
|
|
||||||
x'' + w**2 * x = 0
|
|
||||||
|
|
||||||
i.e. x1 = x, x2 = x1':
|
|
||||||
|
|
||||||
x1' = x2
|
|
||||||
x2' = -x1
|
|
||||||
"""
|
|
||||||
def derivs((x1, x2), t):
|
|
||||||
return x2, -x1
|
|
||||||
|
|
||||||
for solver in solvers:
|
|
||||||
t = arange(0, 3.1415926, 0.005)
|
|
||||||
sol = odeint(derivs, (0., 1.), t, solver)
|
|
||||||
x1 = [a[0] for a in sol]
|
|
||||||
x2 = [a[1] for a in sol]
|
|
||||||
# the result is x1 = sin(t), x2 = cos(t)
|
|
||||||
# let's just check the end points for t = pi
|
|
||||||
assert abs(x1[-1]) < 1e-2
|
|
||||||
assert abs(x2[-1] - (-1)) < 1e-2
|
|
||||||
|
|
||||||
def test_ode2():
|
|
||||||
"""
|
|
||||||
Let's solve:
|
|
||||||
|
|
||||||
x' - x = 0
|
|
||||||
|
|
||||||
i.e. x = exp(x)
|
|
||||||
|
|
||||||
"""
|
|
||||||
def derivs((x), t):
|
|
||||||
return x
|
|
||||||
|
|
||||||
for solver in solvers:
|
|
||||||
t = arange(0, 1, 1e-3)
|
|
||||||
sol = odeint(derivs, (1.,), t, solver)
|
|
||||||
x = [a[0] for a in sol]
|
|
||||||
# the result is x = exp(t)
|
|
||||||
# let's just check the end point for t = 1, i.e. x = e
|
|
||||||
assert abs(x[-1] - 2.718281828) < 1e-2
|
|
||||||
'''
|
|
||||||
|
|
||||||
def test_odefun_rational():
|
|
||||||
mp.dps = 15
|
|
||||||
# A rational function
|
|
||||||
f = lambda t: 1/(1+mpf(t)**2)
|
|
||||||
g = odefun(lambda x, y: [-2*x*y[0]**2], 0, [f(0)])
|
|
||||||
assert f(2).ae(g(2)[0])
|
|
||||||
|
|
||||||
def test_odefun_sinc_large():
|
|
||||||
mp.dps = 15
|
|
||||||
# Sinc function; test for large x
|
|
||||||
f = sinc
|
|
||||||
g = odefun(lambda x, y: [(cos(x)-y[0])/x], 1, [f(1)], tol=0.01, degree=5)
|
|
||||||
assert abs(f(100) - g(100)[0])/f(100) < 0.01
|
|
||||||
|
|
||||||
def test_odefun_harmonic():
|
|
||||||
mp.dps = 15
|
|
||||||
# Harmonic oscillator
|
|
||||||
f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
|
|
||||||
for x in [0, 1, 2.5, 8, 3.7]: # we go back to 3.7 to check caching
|
|
||||||
c, s = f(x)
|
|
||||||
assert c.ae(cos(x))
|
|
||||||
assert s.ae(sin(x))
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def pickler(obj):
|
|
||||||
fn = tempfile.mktemp()
|
|
||||||
|
|
||||||
f = open(fn, 'wb')
|
|
||||||
pickle.dump(obj, f)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
f = open(fn, 'rb')
|
|
||||||
obj2 = pickle.load(f)
|
|
||||||
f.close()
|
|
||||||
os.remove(fn)
|
|
||||||
|
|
||||||
return obj2
|
|
||||||
|
|
||||||
def test_pickle():
|
|
||||||
|
|
||||||
obj = mpf('0.5')
|
|
||||||
assert obj == pickler(obj)
|
|
||||||
|
|
||||||
obj = mpc('0.5','0.2')
|
|
||||||
assert obj == pickler(obj)
|
|
||||||
|
|
@ -1,155 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
from mpmath.libmp import *
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
def test_fractional_pow():
|
|
||||||
assert mpf(16) ** 2.5 == 1024
|
|
||||||
assert mpf(64) ** 0.5 == 8
|
|
||||||
assert mpf(64) ** -0.5 == 0.125
|
|
||||||
assert mpf(16) ** -2.5 == 0.0009765625
|
|
||||||
assert (mpf(10) ** 0.5).ae(3.1622776601683791)
|
|
||||||
assert (mpf(10) ** 2.5).ae(316.2277660168379)
|
|
||||||
assert (mpf(10) ** -0.5).ae(0.31622776601683794)
|
|
||||||
assert (mpf(10) ** -2.5).ae(0.0031622776601683794)
|
|
||||||
assert (mpf(10) ** 0.3).ae(1.9952623149688795)
|
|
||||||
assert (mpf(10) ** -0.3).ae(0.50118723362727224)
|
|
||||||
|
|
||||||
def test_pow_integer_direction():
|
|
||||||
"""
|
|
||||||
Test that inexact integer powers are rounded in the right
|
|
||||||
direction.
|
|
||||||
"""
|
|
||||||
random.seed(1234)
|
|
||||||
for prec in [10, 53, 200]:
|
|
||||||
for i in range(50):
|
|
||||||
a = random.randint(1<<(prec-1), 1<<prec)
|
|
||||||
b = random.randint(2, 100)
|
|
||||||
ab = a**b
|
|
||||||
# note: could actually be exact, but that's very unlikely!
|
|
||||||
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_down)) < ab
|
|
||||||
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_up)) > ab
|
|
||||||
|
|
||||||
|
|
||||||
def test_pow_epsilon_rounding():
|
|
||||||
"""
|
|
||||||
Stress test directed rounding for powers with integer exponents.
|
|
||||||
Basically, we look at the following cases:
|
|
||||||
|
|
||||||
>>> 1.0001 ** -5
|
|
||||||
0.99950014996500702
|
|
||||||
>>> 0.9999 ** -5
|
|
||||||
1.000500150035007
|
|
||||||
>>> (-1.0001) ** -5
|
|
||||||
-0.99950014996500702
|
|
||||||
>>> (-0.9999) ** -5
|
|
||||||
-1.000500150035007
|
|
||||||
|
|
||||||
>>> 1.0001 ** -6
|
|
||||||
0.99940020994401269
|
|
||||||
>>> 0.9999 ** -6
|
|
||||||
1.0006002100560125
|
|
||||||
>>> (-1.0001) ** -6
|
|
||||||
0.99940020994401269
|
|
||||||
>>> (-0.9999) ** -6
|
|
||||||
1.0006002100560125
|
|
||||||
|
|
||||||
etc.
|
|
||||||
|
|
||||||
We run the tests with values a very small epsilon away from 1:
|
|
||||||
small enough that the result is indistinguishable from 1 when
|
|
||||||
rounded to nearest at the output precision. We check that the
|
|
||||||
result is not erroneously rounded to 1 in cases where the
|
|
||||||
rounding should be done strictly away from 1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def powr(x, n, r):
|
|
||||||
return make_mpf(mpf_pow_int(x._mpf_, n, mp.prec, r))
|
|
||||||
|
|
||||||
for (inprec, outprec) in [(100, 20), (5000, 3000)]:
|
|
||||||
|
|
||||||
mp.prec = inprec
|
|
||||||
|
|
||||||
pos10001 = mpf(1) + mpf(2)**(-inprec+5)
|
|
||||||
pos09999 = mpf(1) - mpf(2)**(-inprec+5)
|
|
||||||
neg10001 = -pos10001
|
|
||||||
neg09999 = -pos09999
|
|
||||||
|
|
||||||
mp.prec = outprec
|
|
||||||
r = round_up
|
|
||||||
assert powr(pos10001, 5, r) > 1
|
|
||||||
assert powr(pos09999, 5, r) == 1
|
|
||||||
assert powr(neg10001, 5, r) < -1
|
|
||||||
assert powr(neg09999, 5, r) == -1
|
|
||||||
assert powr(pos10001, 6, r) > 1
|
|
||||||
assert powr(pos09999, 6, r) == 1
|
|
||||||
assert powr(neg10001, 6, r) > 1
|
|
||||||
assert powr(neg09999, 6, r) == 1
|
|
||||||
|
|
||||||
assert powr(pos10001, -5, r) == 1
|
|
||||||
assert powr(pos09999, -5, r) > 1
|
|
||||||
assert powr(neg10001, -5, r) == -1
|
|
||||||
assert powr(neg09999, -5, r) < -1
|
|
||||||
assert powr(pos10001, -6, r) == 1
|
|
||||||
assert powr(pos09999, -6, r) > 1
|
|
||||||
assert powr(neg10001, -6, r) == 1
|
|
||||||
assert powr(neg09999, -6, r) > 1
|
|
||||||
|
|
||||||
r = round_down
|
|
||||||
assert powr(pos10001, 5, r) == 1
|
|
||||||
assert powr(pos09999, 5, r) < 1
|
|
||||||
assert powr(neg10001, 5, r) == -1
|
|
||||||
assert powr(neg09999, 5, r) > -1
|
|
||||||
assert powr(pos10001, 6, r) == 1
|
|
||||||
assert powr(pos09999, 6, r) < 1
|
|
||||||
assert powr(neg10001, 6, r) == 1
|
|
||||||
assert powr(neg09999, 6, r) < 1
|
|
||||||
|
|
||||||
assert powr(pos10001, -5, r) < 1
|
|
||||||
assert powr(pos09999, -5, r) == 1
|
|
||||||
assert powr(neg10001, -5, r) > -1
|
|
||||||
assert powr(neg09999, -5, r) == -1
|
|
||||||
assert powr(pos10001, -6, r) < 1
|
|
||||||
assert powr(pos09999, -6, r) == 1
|
|
||||||
assert powr(neg10001, -6, r) < 1
|
|
||||||
assert powr(neg09999, -6, r) == 1
|
|
||||||
|
|
||||||
r = round_ceiling
|
|
||||||
assert powr(pos10001, 5, r) > 1
|
|
||||||
assert powr(pos09999, 5, r) == 1
|
|
||||||
assert powr(neg10001, 5, r) == -1
|
|
||||||
assert powr(neg09999, 5, r) > -1
|
|
||||||
assert powr(pos10001, 6, r) > 1
|
|
||||||
assert powr(pos09999, 6, r) == 1
|
|
||||||
assert powr(neg10001, 6, r) > 1
|
|
||||||
assert powr(neg09999, 6, r) == 1
|
|
||||||
|
|
||||||
assert powr(pos10001, -5, r) == 1
|
|
||||||
assert powr(pos09999, -5, r) > 1
|
|
||||||
assert powr(neg10001, -5, r) > -1
|
|
||||||
assert powr(neg09999, -5, r) == -1
|
|
||||||
assert powr(pos10001, -6, r) == 1
|
|
||||||
assert powr(pos09999, -6, r) > 1
|
|
||||||
assert powr(neg10001, -6, r) == 1
|
|
||||||
assert powr(neg09999, -6, r) > 1
|
|
||||||
|
|
||||||
r = round_floor
|
|
||||||
assert powr(pos10001, 5, r) == 1
|
|
||||||
assert powr(pos09999, 5, r) < 1
|
|
||||||
assert powr(neg10001, 5, r) < -1
|
|
||||||
assert powr(neg09999, 5, r) == -1
|
|
||||||
assert powr(pos10001, 6, r) == 1
|
|
||||||
assert powr(pos09999, 6, r) < 1
|
|
||||||
assert powr(neg10001, 6, r) == 1
|
|
||||||
assert powr(neg09999, 6, r) < 1
|
|
||||||
|
|
||||||
assert powr(pos10001, -5, r) < 1
|
|
||||||
assert powr(pos09999, -5, r) == 1
|
|
||||||
assert powr(neg10001, -5, r) == -1
|
|
||||||
assert powr(neg09999, -5, r) < -1
|
|
||||||
assert powr(pos10001, -6, r) < 1
|
|
||||||
assert powr(pos09999, -6, r) == 1
|
|
||||||
assert powr(neg10001, -6, r) < 1
|
|
||||||
assert powr(neg09999, -6, r) == 1
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
@ -1,85 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def ae(a, b):
|
|
||||||
return abs(a-b) < 10**(-mp.dps+5)
|
|
||||||
|
|
||||||
def test_basic_integrals():
|
|
||||||
for prec in [15, 30, 100]:
|
|
||||||
mp.dps = prec
|
|
||||||
assert ae(quadts(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
|
|
||||||
assert ae(quadgl(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
|
|
||||||
assert ae(quadts(sin, [0, pi]), 2)
|
|
||||||
assert ae(quadts(sin, [0, 2*pi]), 0)
|
|
||||||
assert ae(quadts(exp, [-inf, -1]), 1/e)
|
|
||||||
assert ae(quadts(lambda x: exp(-x), [0, inf]), 1)
|
|
||||||
assert ae(quadts(lambda x: exp(-x*x), [-inf, inf]), sqrt(pi))
|
|
||||||
assert ae(quadts(lambda x: 1/(1+x*x), [-1, 1]), pi/2)
|
|
||||||
assert ae(quadts(lambda x: 1/(1+x*x), [-inf, inf]), pi)
|
|
||||||
assert ae(quadts(lambda x: 2*sqrt(1-x*x), [-1, 1]), pi)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
def test_quad_symmetry():
|
|
||||||
assert quadts(sin, [-1, 1]) == 0
|
|
||||||
assert quadgl(sin, [-1, 1]) == 0
|
|
||||||
|
|
||||||
def test_quadgl_linear():
|
|
||||||
assert quadgl(lambda x: x, [0, 1], maxdegree=1).ae(0.5)
|
|
||||||
|
|
||||||
def test_complex_integration():
|
|
||||||
assert quadts(lambda x: x, [0, 1+j]).ae(j)
|
|
||||||
|
|
||||||
def test_quadosc():
|
|
||||||
mp.dps = 15
|
|
||||||
assert quadosc(lambda x: sin(x)/x, [0, inf], period=2*pi).ae(pi/2)
|
|
||||||
|
|
||||||
# Double integrals
|
|
||||||
def test_double_trivial():
|
|
||||||
assert ae(quadts(lambda x, y: x, [0, 1], [0, 1]), 0.5)
|
|
||||||
assert ae(quadts(lambda x, y: x, [-1, 1], [-1, 1]), 0.0)
|
|
||||||
|
|
||||||
def test_double_1():
|
|
||||||
assert ae(quadts(lambda x, y: cos(x+y/2), [-pi/2, pi/2], [0, pi]), 4)
|
|
||||||
|
|
||||||
def test_double_2():
|
|
||||||
assert ae(quadts(lambda x, y: (x-1)/((1-x*y)*log(x*y)), [0, 1], [0, 1]), euler)
|
|
||||||
|
|
||||||
def test_double_3():
|
|
||||||
assert ae(quadts(lambda x, y: 1/sqrt(1+x*x+y*y), [-1, 1], [-1, 1]), 4*log(2+sqrt(3))-2*pi/3)
|
|
||||||
|
|
||||||
def test_double_4():
|
|
||||||
assert ae(quadts(lambda x, y: 1/(1-x*x * y*y), [0, 1], [0, 1]), pi**2 / 8)
|
|
||||||
|
|
||||||
def test_double_5():
|
|
||||||
assert ae(quadts(lambda x, y: 1/(1-x*y), [0, 1], [0, 1]), pi**2 / 6)
|
|
||||||
|
|
||||||
def test_double_6():
|
|
||||||
assert ae(quadts(lambda x, y: exp(-(x+y)), [0, inf], [0, inf]), 1)
|
|
||||||
|
|
||||||
# fails
|
|
||||||
def xtest_double_7():
|
|
||||||
assert ae(quadts(lambda x, y: exp(-x*x-y*y), [-inf, inf], [-inf, inf]), pi)
|
|
||||||
|
|
||||||
|
|
||||||
# Test integrals from "Experimentation in Mathematics" by Borwein,
|
|
||||||
# Bailey & Girgensohn
|
|
||||||
def test_expmath_integrals():
|
|
||||||
for prec in [15, 30, 50]:
|
|
||||||
mp.dps = prec
|
|
||||||
assert ae(quadts(lambda x: x/sinh(x), [0, inf]), pi**2 / 4)
|
|
||||||
assert ae(quadts(lambda x: log(x)**2 / (1+x**2), [0, inf]), pi**3 / 8)
|
|
||||||
assert ae(quadts(lambda x: (1+x**2)/(1+x**4), [0, inf]), pi/sqrt(2))
|
|
||||||
assert ae(quadts(lambda x: log(x)/cosh(x)**2, [0, inf]), log(pi)-2*log(2)-euler)
|
|
||||||
assert ae(quadts(lambda x: log(1+x**3)/(1-x+x**2), [0, inf]), 2*pi*log(3)/sqrt(3))
|
|
||||||
assert ae(quadts(lambda x: log(x)**2 / (x**2+x+1), [0, 1]), 8*pi**3 / (81*sqrt(3)))
|
|
||||||
assert ae(quadts(lambda x: log(cos(x))**2, [0, pi/2]), pi/2 * (log(2)**2+pi**2/12))
|
|
||||||
assert ae(quadts(lambda x: x**2 / sin(x)**2, [0, pi/2]), pi*log(2))
|
|
||||||
assert ae(quadts(lambda x: x**2/sqrt(exp(x)-1), [0, inf]), 4*pi*(log(2)**2 + pi**2/12))
|
|
||||||
assert ae(quadts(lambda x: x*exp(-x)*sqrt(1-exp(-2*x)), [0, inf]), pi*(1+2*log(2))/8)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
# Do not reach full accuracy
|
|
||||||
def xtest_expmath_fail():
|
|
||||||
assert ae(quadts(lambda x: sqrt(tan(x)), [0, pi/2]), pi*sqrt(2)/2)
|
|
||||||
assert ae(quadts(lambda x: atan(x)/(x*sqrt(1-x**2)), [0, 1]), pi*log(1+sqrt(2))/2)
|
|
||||||
assert ae(quadts(lambda x: log(1+x**2)/x**2, [0, 1]), pi/2-log(2))
|
|
||||||
assert ae(quadts(lambda x: x**2/((1+x**4)*sqrt(1-x**4)), [0, 1]), pi/8)
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
from mpmath.calculus.optimization import Secant, Muller, Bisection, Illinois, \
|
|
||||||
Pegasus, Anderson, Ridder, ANewton, Newton, MNewton, MDNewton
|
|
||||||
|
|
||||||
def test_findroot():
|
|
||||||
# old tests, assuming secant
|
|
||||||
mp.dps = 15
|
|
||||||
assert findroot(lambda x: 4*x-3, mpf(5)).ae(0.75)
|
|
||||||
assert findroot(sin, mpf(3)).ae(pi)
|
|
||||||
assert findroot(sin, (mpf(3), mpf(3.14))).ae(pi)
|
|
||||||
assert findroot(lambda x: x*x+1, mpc(2+2j)).ae(1j)
|
|
||||||
# test all solvers with 1 starting point
|
|
||||||
f = lambda x: cos(x)
|
|
||||||
for solver in [Newton, Secant, MNewton, Muller, ANewton]:
|
|
||||||
x = findroot(f, 2., solver=solver)
|
|
||||||
assert abs(f(x)) < eps
|
|
||||||
# test all solvers with interval of 2 points
|
|
||||||
for solver in [Secant, Muller, Bisection, Illinois, Pegasus, Anderson,
|
|
||||||
Ridder]:
|
|
||||||
x = findroot(f, (1., 2.), solver=solver)
|
|
||||||
assert abs(f(x)) < eps
|
|
||||||
# test types
|
|
||||||
f = lambda x: (x - 2)**2
|
|
||||||
|
|
||||||
#assert isinstance(findroot(f, 1, force_type=mpf, tol=1e-10), mpf)
|
|
||||||
#assert isinstance(findroot(f, 1., force_type=None, tol=1e-10), float)
|
|
||||||
#assert isinstance(findroot(f, 1, force_type=complex, tol=1e-10), complex)
|
|
||||||
assert isinstance(fp.findroot(f, 1, tol=1e-10), float)
|
|
||||||
assert isinstance(fp.findroot(f, 1+0j, tol=1e-10), complex)
|
|
||||||
|
|
||||||
def test_mnewton():
|
|
||||||
f = lambda x: polyval([1,3,3,1],x)
|
|
||||||
x = findroot(f, -0.9, solver='mnewton')
|
|
||||||
assert abs(f(x)) < eps
|
|
||||||
|
|
||||||
def test_anewton():
|
|
||||||
f = lambda x: (x - 2)**100
|
|
||||||
x = findroot(f, 1., solver=ANewton)
|
|
||||||
assert abs(f(x)) < eps
|
|
||||||
|
|
||||||
def test_muller():
|
|
||||||
f = lambda x: (2 + x)**3 + 2
|
|
||||||
x = findroot(f, 1., solver=Muller)
|
|
||||||
assert abs(f(x)) < eps
|
|
||||||
|
|
||||||
def test_multiplicity():
|
|
||||||
for i in xrange(1, 5):
|
|
||||||
assert multiplicity(lambda x: (x - 1)**i, 1) == i
|
|
||||||
assert multiplicity(lambda x: x**2, 1) == 0
|
|
||||||
|
|
||||||
def test_multidimensional():
|
|
||||||
def f(*x):
|
|
||||||
return [3*x[0]**2-2*x[1]**2-1, x[0]**2-2*x[0]+x[1]**2+2*x[1]-8]
|
|
||||||
assert mnorm(jacobian(f, (1,-2)) - matrix([[6,8],[0,-2]]),1) < 1.e-7
|
|
||||||
for x, error in MDNewton(mp, f, (1,-2), verbose=0,
|
|
||||||
norm=lambda x: norm(x, inf)):
|
|
||||||
pass
|
|
||||||
assert norm(f(*x), 2) < 1e-14
|
|
||||||
# The Chinese mathematician Zhu Shijie was the very first to solve this
|
|
||||||
# nonlinear system 700 years ago
|
|
||||||
f1 = lambda x, y: -x + 2*y
|
|
||||||
f2 = lambda x, y: (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
|
|
||||||
f3 = lambda x, y: sqrt(x**2 + y**2)
|
|
||||||
def f(x, y):
|
|
||||||
f1x = f1(x, y)
|
|
||||||
return (f2(x, y) - f1x, f3(x, y) - f1x)
|
|
||||||
x = findroot(f, (10, 10))
|
|
||||||
assert [int(round(i)) for i in x] == [3, 4]
|
|
||||||
|
|
||||||
def test_trivial():
|
|
||||||
assert findroot(lambda x: 0, 1) == 1
|
|
||||||
assert findroot(lambda x: x, 0) == 0
|
|
||||||
#assert findroot(lambda x, y: x + y, (1, -1)) == (1, -1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,112 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_special():
|
|
||||||
assert inf == inf
|
|
||||||
assert inf != -inf
|
|
||||||
assert -inf == -inf
|
|
||||||
assert inf != nan
|
|
||||||
assert nan != nan
|
|
||||||
assert isnan(nan)
|
|
||||||
assert --inf == inf
|
|
||||||
assert abs(inf) == inf
|
|
||||||
assert abs(-inf) == inf
|
|
||||||
assert abs(nan) != abs(nan)
|
|
||||||
|
|
||||||
assert isnan(inf - inf)
|
|
||||||
assert isnan(inf + (-inf))
|
|
||||||
assert isnan(-inf - (-inf))
|
|
||||||
|
|
||||||
assert isnan(inf + nan)
|
|
||||||
assert isnan(-inf + nan)
|
|
||||||
|
|
||||||
assert mpf(2) + inf == inf
|
|
||||||
assert 2 + inf == inf
|
|
||||||
assert mpf(2) - inf == -inf
|
|
||||||
assert 2 - inf == -inf
|
|
||||||
|
|
||||||
assert inf > 3
|
|
||||||
assert 3 < inf
|
|
||||||
assert 3 > -inf
|
|
||||||
assert -inf < 3
|
|
||||||
assert inf > mpf(3)
|
|
||||||
assert mpf(3) < inf
|
|
||||||
assert mpf(3) > -inf
|
|
||||||
assert -inf < mpf(3)
|
|
||||||
|
|
||||||
assert not (nan < 3)
|
|
||||||
assert not (nan > 3)
|
|
||||||
|
|
||||||
assert isnan(inf * 0)
|
|
||||||
assert isnan(-inf * 0)
|
|
||||||
assert inf * 3 == inf
|
|
||||||
assert inf * -3 == -inf
|
|
||||||
assert -inf * 3 == -inf
|
|
||||||
assert -inf * -3 == inf
|
|
||||||
assert inf * inf == inf
|
|
||||||
assert -inf * -inf == inf
|
|
||||||
|
|
||||||
assert isnan(nan / 3)
|
|
||||||
assert inf / -3 == -inf
|
|
||||||
assert inf / 3 == inf
|
|
||||||
assert 3 / inf == 0
|
|
||||||
assert -3 / inf == 0
|
|
||||||
assert 0 / inf == 0
|
|
||||||
assert isnan(inf / inf)
|
|
||||||
assert isnan(inf / -inf)
|
|
||||||
assert isnan(inf / nan)
|
|
||||||
|
|
||||||
assert mpf('inf') == mpf('+inf') == inf
|
|
||||||
assert mpf('-inf') == -inf
|
|
||||||
assert isnan(mpf('nan'))
|
|
||||||
|
|
||||||
assert isinf(inf)
|
|
||||||
assert isinf(-inf)
|
|
||||||
assert not isinf(mpf(0))
|
|
||||||
assert not isinf(nan)
|
|
||||||
|
|
||||||
def test_special_powers():
|
|
||||||
assert inf**3 == inf
|
|
||||||
assert isnan(inf**0)
|
|
||||||
assert inf**-3 == 0
|
|
||||||
assert (-inf)**2 == inf
|
|
||||||
assert (-inf)**3 == -inf
|
|
||||||
assert isnan((-inf)**0)
|
|
||||||
assert (-inf)**-2 == 0
|
|
||||||
assert (-inf)**-3 == 0
|
|
||||||
assert isnan(nan**5)
|
|
||||||
assert isnan(nan**0)
|
|
||||||
|
|
||||||
def test_functions_special():
|
|
||||||
assert exp(inf) == inf
|
|
||||||
assert exp(-inf) == 0
|
|
||||||
assert isnan(exp(nan))
|
|
||||||
assert log(inf) == inf
|
|
||||||
assert isnan(sin(inf))
|
|
||||||
assert isnan(sin(nan))
|
|
||||||
assert atan(inf).ae(pi/2)
|
|
||||||
assert atan(-inf).ae(-pi/2)
|
|
||||||
assert isnan(sqrt(nan))
|
|
||||||
assert sqrt(inf) == inf
|
|
||||||
|
|
||||||
def test_convert_special():
|
|
||||||
float_inf = 1e300 * 1e300
|
|
||||||
float_ninf = -float_inf
|
|
||||||
float_nan = float_inf/float_ninf
|
|
||||||
assert mpf(3) * float_inf == inf
|
|
||||||
assert mpf(3) * float_ninf == -inf
|
|
||||||
assert isnan(mpf(3) * float_nan)
|
|
||||||
assert not (mpf(3) < float_nan)
|
|
||||||
assert not (mpf(3) > float_nan)
|
|
||||||
assert not (mpf(3) <= float_nan)
|
|
||||||
assert not (mpf(3) >= float_nan)
|
|
||||||
assert float(mpf('1e1000')) == float_inf
|
|
||||||
assert float(mpf('-1e1000')) == float_ninf
|
|
||||||
assert float(mpf('1e100000000000000000')) == float_inf
|
|
||||||
assert float(mpf('-1e100000000000000000')) == float_ninf
|
|
||||||
assert float(mpf('1e-100000000000000000')) == 0.0
|
|
||||||
|
|
||||||
def test_div_bug():
|
|
||||||
assert isnan(nan/1)
|
|
||||||
assert isnan(nan/2)
|
|
||||||
assert inf/2 == inf
|
|
||||||
assert (-inf)/2 == -inf
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
from mpmath import nstr, matrix, inf
|
|
||||||
|
|
||||||
def test_nstr():
|
|
||||||
m = matrix([[0.75, 0.190940654, -0.0299195971],
|
|
||||||
[0.190940654, 0.65625, 0.205663228],
|
|
||||||
[-0.0299195971, 0.205663228, 0.64453125e-20]])
|
|
||||||
assert nstr(m, 4, min_fixed=-inf) == \
|
|
||||||
'''[ 0.75 0.1909 -0.02992]
|
|
||||||
[ 0.1909 0.6563 0.2057]
|
|
||||||
[-0.02992 0.2057 0.000000000000000000006445]'''
|
|
||||||
assert nstr(m, 4) == \
|
|
||||||
'''[ 0.75 0.1909 -0.02992]
|
|
||||||
[ 0.1909 0.6563 0.2057]
|
|
||||||
[-0.02992 0.2057 6.445e-21]'''
|
|
||||||
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_sumem():
|
|
||||||
mp.dps = 15
|
|
||||||
assert sumem(lambda k: 1/k**2.5, [50, 100]).ae(0.0012524505324784962)
|
|
||||||
assert sumem(lambda k: k**4 + 3*k + 1, [10, 100]).ae(2050333103)
|
|
||||||
|
|
||||||
def test_nsum():
|
|
||||||
mp.dps = 15
|
|
||||||
assert nsum(lambda x: x**2, [1, 3]) == 14
|
|
||||||
assert nsum(lambda k: 1/factorial(k), [0, inf]).ae(e)
|
|
||||||
assert nsum(lambda k: (-1)**(k+1) / k, [1, inf]).ae(log(2))
|
|
||||||
assert nsum(lambda k: (-1)**(k+1) / k**2, [1, inf]).ae(pi**2 / 12)
|
|
||||||
assert nsum(lambda k: (-1)**k / log(k), [2, inf]).ae(0.9242998972229388)
|
|
||||||
assert nsum(lambda k: 1/k**2, [1, inf]).ae(pi**2 / 6)
|
|
||||||
assert nsum(lambda k: 2**k/fac(k), [0, inf]).ae(exp(2))
|
|
||||||
assert nsum(lambda k: 1/k**2, [4, inf], method='e').ae(0.2838229557371153)
|
|
||||||
|
|
||||||
def test_nprod():
|
|
||||||
mp.dps = 15
|
|
||||||
assert nprod(lambda k: exp(1/k**2), [1,inf], method='r').ae(exp(pi**2/6))
|
|
||||||
assert nprod(lambda x: x**2, [1, 3]) == 36
|
|
||||||
|
|
||||||
def test_fsum():
|
|
||||||
mp.dps = 15
|
|
||||||
assert fsum([]) == 0
|
|
||||||
assert fsum([-4]) == -4
|
|
||||||
assert fsum([2,3]) == 5
|
|
||||||
assert fsum([1e-100,1]) == 1
|
|
||||||
assert fsum([1,1e-100]) == 1
|
|
||||||
assert fsum([1e100,1]) == 1e100
|
|
||||||
assert fsum([1,1e100]) == 1e100
|
|
||||||
assert fsum([1e-100,0]) == 1e-100
|
|
||||||
assert fsum([1e-100,1e100,1e-100]) == 1e100
|
|
||||||
assert fsum([2,1+1j,1]) == 4+1j
|
|
||||||
assert fsum([1,mpi(2,3)]) == mpi(3,4)
|
|
||||||
assert fsum([2,inf,3]) == inf
|
|
||||||
assert fsum([2,-1], absolute=1) == 3
|
|
||||||
assert fsum([2,-1], squared=1) == 5
|
|
||||||
assert fsum([1,1+j], squared=1) == 1+2j
|
|
||||||
assert fsum([1,3+4j], absolute=1) == 6
|
|
||||||
assert fsum([1,2+3j], absolute=1, squared=1) == 14
|
|
||||||
assert isnan(fsum([inf,-inf]))
|
|
||||||
assert fsum([inf,-inf], absolute=1) == inf
|
|
||||||
assert fsum([inf,-inf], squared=1) == inf
|
|
||||||
assert fsum([inf,-inf], absolute=1, squared=1) == inf
|
|
||||||
|
|
||||||
def test_fprod():
|
|
||||||
mp.dps = 15
|
|
||||||
assert fprod([]) == 1
|
|
||||||
assert fprod([2,3]) == 6
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
from mpmath import *
|
|
||||||
from mpmath.libmp import *
|
|
||||||
|
|
||||||
def test_trig_misc_hard():
|
|
||||||
mp.prec = 53
|
|
||||||
# Worst-case input for an IEEE double, from a paper by Kahan
|
|
||||||
x = ldexp(6381956970095103,797)
|
|
||||||
assert cos(x) == mpf('-4.6871659242546277e-19')
|
|
||||||
assert sin(x) == 1
|
|
||||||
|
|
||||||
mp.prec = 150
|
|
||||||
a = mpf(10**50)
|
|
||||||
mp.prec = 53
|
|
||||||
assert sin(a).ae(-0.7896724934293100827)
|
|
||||||
assert cos(a).ae(-0.6135286082336635622)
|
|
||||||
|
|
||||||
# Check relative accuracy close to x = zero
|
|
||||||
assert sin(1e-100) == 1e-100 # when rounding to nearest
|
|
||||||
assert sin(1e-6).ae(9.999999999998333e-007, rel_eps=2e-15, abs_eps=0)
|
|
||||||
assert sin(1e-6j).ae(1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
|
|
||||||
assert sin(-1e-6j).ae(-1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
|
|
||||||
assert cos(1e-100) == 1
|
|
||||||
assert cos(1e-6).ae(0.9999999999995)
|
|
||||||
assert cos(-1e-6j).ae(1.0000000000005)
|
|
||||||
assert tan(1e-100) == 1e-100
|
|
||||||
assert tan(1e-6).ae(1.0000000000003335e-006, rel_eps=2e-15, abs_eps=0)
|
|
||||||
assert tan(1e-6j).ae(9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
|
|
||||||
assert tan(-1e-6j).ae(-9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
|
|
||||||
|
|
||||||
def test_trig_near_zero():
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
for r in [round_nearest, round_down, round_up, round_floor, round_ceiling]:
|
|
||||||
assert sin(0, rounding=r) == 0
|
|
||||||
assert cos(0, rounding=r) == 1
|
|
||||||
|
|
||||||
a = mpf('1e-100')
|
|
||||||
b = mpf('-1e-100')
|
|
||||||
|
|
||||||
assert sin(a, rounding=round_nearest) == a
|
|
||||||
assert sin(a, rounding=round_down) < a
|
|
||||||
assert sin(a, rounding=round_floor) < a
|
|
||||||
assert sin(a, rounding=round_up) >= a
|
|
||||||
assert sin(a, rounding=round_ceiling) >= a
|
|
||||||
assert sin(b, rounding=round_nearest) == b
|
|
||||||
assert sin(b, rounding=round_down) > b
|
|
||||||
assert sin(b, rounding=round_floor) <= b
|
|
||||||
assert sin(b, rounding=round_up) <= b
|
|
||||||
assert sin(b, rounding=round_ceiling) > b
|
|
||||||
|
|
||||||
assert cos(a, rounding=round_nearest) == 1
|
|
||||||
assert cos(a, rounding=round_down) < 1
|
|
||||||
assert cos(a, rounding=round_floor) < 1
|
|
||||||
assert cos(a, rounding=round_up) == 1
|
|
||||||
assert cos(a, rounding=round_ceiling) == 1
|
|
||||||
assert cos(b, rounding=round_nearest) == 1
|
|
||||||
assert cos(b, rounding=round_down) < 1
|
|
||||||
assert cos(b, rounding=round_floor) < 1
|
|
||||||
assert cos(b, rounding=round_up) == 1
|
|
||||||
assert cos(b, rounding=round_ceiling) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_trig_near_n_pi():
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
a = [n*pi for n in [1, 2, 6, 11, 100, 1001, 10000, 100001]]
|
|
||||||
mp.dps = 135
|
|
||||||
a.append(10**100 * pi)
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
assert sin(a[0]) == mpf('1.2246467991473531772e-16')
|
|
||||||
assert sin(a[1]) == mpf('-2.4492935982947063545e-16')
|
|
||||||
assert sin(a[2]) == mpf('-7.3478807948841190634e-16')
|
|
||||||
assert sin(a[3]) == mpf('4.8998251578625894243e-15')
|
|
||||||
assert sin(a[4]) == mpf('1.9643867237284719452e-15')
|
|
||||||
assert sin(a[5]) == mpf('-8.8632615209684813458e-15')
|
|
||||||
assert sin(a[6]) == mpf('-4.8568235395684898392e-13')
|
|
||||||
assert sin(a[7]) == mpf('3.9087342299491231029e-11')
|
|
||||||
assert sin(a[8]) == mpf('-1.369235466754566993528e-36')
|
|
||||||
|
|
||||||
r = round_nearest
|
|
||||||
assert cos(a[0], rounding=r) == -1
|
|
||||||
assert cos(a[1], rounding=r) == 1
|
|
||||||
assert cos(a[2], rounding=r) == 1
|
|
||||||
assert cos(a[3], rounding=r) == -1
|
|
||||||
assert cos(a[4], rounding=r) == 1
|
|
||||||
assert cos(a[5], rounding=r) == -1
|
|
||||||
assert cos(a[6], rounding=r) == 1
|
|
||||||
assert cos(a[7], rounding=r) == -1
|
|
||||||
assert cos(a[8], rounding=r) == 1
|
|
||||||
|
|
||||||
r = round_up
|
|
||||||
assert cos(a[0], rounding=r) == -1
|
|
||||||
assert cos(a[1], rounding=r) == 1
|
|
||||||
assert cos(a[2], rounding=r) == 1
|
|
||||||
assert cos(a[3], rounding=r) == -1
|
|
||||||
assert cos(a[4], rounding=r) == 1
|
|
||||||
assert cos(a[5], rounding=r) == -1
|
|
||||||
assert cos(a[6], rounding=r) == 1
|
|
||||||
assert cos(a[7], rounding=r) == -1
|
|
||||||
assert cos(a[8], rounding=r) == 1
|
|
||||||
|
|
||||||
r = round_down
|
|
||||||
assert cos(a[0], rounding=r) > -1
|
|
||||||
assert cos(a[1], rounding=r) < 1
|
|
||||||
assert cos(a[2], rounding=r) < 1
|
|
||||||
assert cos(a[3], rounding=r) > -1
|
|
||||||
assert cos(a[4], rounding=r) < 1
|
|
||||||
assert cos(a[5], rounding=r) > -1
|
|
||||||
assert cos(a[6], rounding=r) < 1
|
|
||||||
assert cos(a[7], rounding=r) > -1
|
|
||||||
assert cos(a[8], rounding=r) < 1
|
|
||||||
|
|
||||||
r = round_floor
|
|
||||||
assert cos(a[0], rounding=r) == -1
|
|
||||||
assert cos(a[1], rounding=r) < 1
|
|
||||||
assert cos(a[2], rounding=r) < 1
|
|
||||||
assert cos(a[3], rounding=r) == -1
|
|
||||||
assert cos(a[4], rounding=r) < 1
|
|
||||||
assert cos(a[5], rounding=r) == -1
|
|
||||||
assert cos(a[6], rounding=r) < 1
|
|
||||||
assert cos(a[7], rounding=r) == -1
|
|
||||||
assert cos(a[8], rounding=r) < 1
|
|
||||||
|
|
||||||
r = round_ceiling
|
|
||||||
assert cos(a[0], rounding=r) > -1
|
|
||||||
assert cos(a[1], rounding=r) == 1
|
|
||||||
assert cos(a[2], rounding=r) == 1
|
|
||||||
assert cos(a[3], rounding=r) > -1
|
|
||||||
assert cos(a[4], rounding=r) == 1
|
|
||||||
assert cos(a[5], rounding=r) > -1
|
|
||||||
assert cos(a[6], rounding=r) == 1
|
|
||||||
assert cos(a[7], rounding=r) > -1
|
|
||||||
assert cos(a[8], rounding=r) == 1
|
|
||||||
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
for f in globals().keys():
|
|
||||||
if f.startswith("test_"):
|
|
||||||
print f
|
|
||||||
globals()[f]()
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
"""
|
|
||||||
Limited tests of the visualization module. Right now it just makes
|
|
||||||
sure that passing custom Axes works.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from mpmath import mp, fp
|
|
||||||
|
|
||||||
def test_axes():
|
|
||||||
try:
|
|
||||||
import pylab
|
|
||||||
except ImportError:
|
|
||||||
print "\nSkipping test (pylab not available)\n"
|
|
||||||
return
|
|
||||||
fig = pylab.figure()
|
|
||||||
axes = fig.add_subplot(111)
|
|
||||||
for ctx in [mp, fp]:
|
|
||||||
ctx.plot(lambda x: x**2, [0, 3], axes=axes)
|
|
||||||
assert axes.get_xlabel() == 'x'
|
|
||||||
assert axes.get_ylabel() == 'f(x)'
|
|
||||||
|
|
||||||
fig = pylab.figure()
|
|
||||||
axes = fig.add_subplot(111)
|
|
||||||
for ctx in [mp, fp]:
|
|
||||||
ctx.cplot(lambda z: z, [-2, 2], [-10, 10], axes=axes)
|
|
||||||
assert axes.get_xlabel() == 'Re(z)'
|
|
||||||
assert axes.get_ylabel() == 'Im(z)'
|
|
||||||
|
|
@ -1,229 +0,0 @@
|
||||||
"""
|
|
||||||
Torture tests for asymptotics and high precision evaluation of
|
|
||||||
special functions.
|
|
||||||
|
|
||||||
(Other torture tests may also be placed here.)
|
|
||||||
|
|
||||||
Running this file (gmpy and psyco recommended!) takes several CPU minutes.
|
|
||||||
With Python 2.6+, multiprocessing is used automatically to run tests
|
|
||||||
in parallel if many cores are available. (A single test may take between
|
|
||||||
a second and several minutes; possibly more.)
|
|
||||||
|
|
||||||
The idea:
|
|
||||||
|
|
||||||
* We evaluate functions at positive, negative, imaginary, 45- and 135-degree
|
|
||||||
complex values with magnitudes between 10^-20 to 10^20, at precisions between
|
|
||||||
5 and 150 digits (we can go even higher for fast functions).
|
|
||||||
|
|
||||||
* Comparing the result from two different precision levels provides
|
|
||||||
a strong consistency check (particularly for functions that use
|
|
||||||
different algorithms at different precision levels).
|
|
||||||
|
|
||||||
* That the computation finishes at all (without failure), within reasonable
|
|
||||||
time, provides a check that evaluation works at all: that the code runs,
|
|
||||||
that it doesn't get stuck in an infinite loop, and that it doesn't use
|
|
||||||
some extremely slowly algorithm where it could use a faster one.
|
|
||||||
|
|
||||||
TODO:
|
|
||||||
|
|
||||||
* Speed up those functions that take long to finish!
|
|
||||||
* Generalize to test more cases; more options.
|
|
||||||
* Implement a timeout mechanism.
|
|
||||||
* Some functions are notably absent, including the following:
|
|
||||||
* inverse trigonometric functions (some become inaccurate for complex arguments)
|
|
||||||
* ci, si (not implemented properly for large complex arguments)
|
|
||||||
* zeta functions (need to modify test not to try too large imaginary values)
|
|
||||||
* and others...
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import sys, os
|
|
||||||
from timeit import default_timer as clock
|
|
||||||
|
|
||||||
if "-psyco" in sys.argv:
|
|
||||||
sys.argv.remove('-psyco')
|
|
||||||
import psyco
|
|
||||||
psyco.full()
|
|
||||||
|
|
||||||
if "-nogmpy" in sys.argv:
|
|
||||||
sys.argv.remove('-nogmpy')
|
|
||||||
os.environ['MPMATH_NOGMPY'] = 'Y'
|
|
||||||
|
|
||||||
filt = ''
|
|
||||||
if not sys.argv[-1].endswith(".py"):
|
|
||||||
filt = sys.argv[-1]
|
|
||||||
|
|
||||||
from mpmath import *
|
|
||||||
|
|
||||||
def test_asymp(f, maxdps=150, verbose=False, huge_range=False):
|
|
||||||
dps = [5,15,25,50,90,150,500,1500,5000,10000]
|
|
||||||
dps = [p for p in dps if p <= maxdps]
|
|
||||||
def check(x,y,p,inpt):
|
|
||||||
if abs(x-y)/abs(y) < workprec(20)(power)(10, -p+1):
|
|
||||||
return
|
|
||||||
print
|
|
||||||
print "Error!"
|
|
||||||
print "Input:", inpt
|
|
||||||
print "dps =", p
|
|
||||||
print "Result 1:", x
|
|
||||||
print "Result 2:", y
|
|
||||||
print "Absolute error:", abs(x-y)
|
|
||||||
print "Relative error:", abs(x-y)/abs(y)
|
|
||||||
raise AssertionError
|
|
||||||
exponents = range(-20,20)
|
|
||||||
if huge_range:
|
|
||||||
exponents += [-1000, -100, -50, 50, 100, 1000]
|
|
||||||
for n in exponents:
|
|
||||||
if verbose:
|
|
||||||
print ".",
|
|
||||||
mp.dps = 25
|
|
||||||
xpos = mpf(10)**n / 1.1287
|
|
||||||
xneg = -xpos
|
|
||||||
ximag = xpos*j
|
|
||||||
xcomplex1 = xpos*(1+j)
|
|
||||||
xcomplex2 = xpos*(-1+j)
|
|
||||||
for i in range(len(dps)):
|
|
||||||
if verbose:
|
|
||||||
print "Testing dps = %s" % dps[i]
|
|
||||||
mp.dps = dps[i]
|
|
||||||
new = f(xpos), f(xneg), f(ximag), f(xcomplex1), f(xcomplex2)
|
|
||||||
if i != 0:
|
|
||||||
p = dps[i-1]
|
|
||||||
check(prev[0], new[0], p, xpos)
|
|
||||||
check(prev[1], new[1], p, xneg)
|
|
||||||
check(prev[2], new[2], p, ximag)
|
|
||||||
check(prev[3], new[3], p, xcomplex1)
|
|
||||||
check(prev[4], new[4], p, xcomplex2)
|
|
||||||
prev = new
|
|
||||||
if verbose:
|
|
||||||
print
|
|
||||||
|
|
||||||
a1, a2, a3, a4, a5 = 1.5, -2.25, 3.125, 4, 2
|
|
||||||
|
|
||||||
def test_bernoulli_huge():
|
|
||||||
p, q = bernfrac(9000)
|
|
||||||
assert p % 10**10 == 9636701091
|
|
||||||
assert q == 4091851784687571609141381951327092757255270
|
|
||||||
mp.dps = 15
|
|
||||||
assert str(bernoulli(10**100)) == '-2.58183325604736e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
|
|
||||||
mp.dps = 50
|
|
||||||
assert str(bernoulli(10**100)) == '-2.5818332560473632073252488656039475548106223822913e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
|
|
||||||
mp.dps = 15
|
|
||||||
|
|
||||||
cases = """\
|
|
||||||
test_bernoulli_huge()
|
|
||||||
test_asymp(lambda z: +pi, maxdps=10000)
|
|
||||||
test_asymp(lambda z: +e, maxdps=10000)
|
|
||||||
test_asymp(lambda z: +ln2, maxdps=10000)
|
|
||||||
test_asymp(lambda z: +ln10, maxdps=10000)
|
|
||||||
test_asymp(lambda z: +phi, maxdps=10000)
|
|
||||||
test_asymp(lambda z: +catalan, maxdps=5000)
|
|
||||||
test_asymp(lambda z: +euler, maxdps=5000)
|
|
||||||
test_asymp(lambda z: +glaisher, maxdps=1000)
|
|
||||||
test_asymp(lambda z: +khinchin, maxdps=1000)
|
|
||||||
test_asymp(lambda z: +twinprime, maxdps=150)
|
|
||||||
test_asymp(lambda z: stieltjes(2), maxdps=150)
|
|
||||||
test_asymp(lambda z: +mertens, maxdps=150)
|
|
||||||
test_asymp(lambda z: +apery, maxdps=5000)
|
|
||||||
test_asymp(sqrt, maxdps=10000, huge_range=True)
|
|
||||||
test_asymp(cbrt, maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(lambda z: root(z,4), maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(lambda z: root(z,-5), maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(exp, maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(expm1, maxdps=1500)
|
|
||||||
test_asymp(ln, maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(cosh, maxdps=5000)
|
|
||||||
test_asymp(sinh, maxdps=5000)
|
|
||||||
test_asymp(tanh, maxdps=1500)
|
|
||||||
test_asymp(sin, maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(cos, maxdps=5000, huge_range=True)
|
|
||||||
test_asymp(tan, maxdps=1500)
|
|
||||||
test_asymp(agm, maxdps=1500, huge_range=True)
|
|
||||||
test_asymp(ellipk, maxdps=1500)
|
|
||||||
test_asymp(ellipe, maxdps=1500)
|
|
||||||
test_asymp(lambertw, huge_range=True)
|
|
||||||
test_asymp(lambda z: lambertw(z,-1))
|
|
||||||
test_asymp(lambda z: lambertw(z,1))
|
|
||||||
test_asymp(lambda z: lambertw(z,4))
|
|
||||||
test_asymp(gamma)
|
|
||||||
test_asymp(loggamma) # huge_range=True ?
|
|
||||||
test_asymp(ei)
|
|
||||||
test_asymp(e1)
|
|
||||||
test_asymp(li, huge_range=True)
|
|
||||||
test_asymp(ci)
|
|
||||||
test_asymp(si)
|
|
||||||
test_asymp(chi)
|
|
||||||
test_asymp(shi)
|
|
||||||
test_asymp(erf)
|
|
||||||
test_asymp(erfc)
|
|
||||||
test_asymp(erfi)
|
|
||||||
test_asymp(lambda z: besselj(2, z))
|
|
||||||
test_asymp(lambda z: bessely(2, z))
|
|
||||||
test_asymp(lambda z: besseli(2, z))
|
|
||||||
test_asymp(lambda z: besselk(2, z))
|
|
||||||
test_asymp(lambda z: besselj(-2.25, z))
|
|
||||||
test_asymp(lambda z: bessely(-2.25, z))
|
|
||||||
test_asymp(lambda z: besseli(-2.25, z))
|
|
||||||
test_asymp(lambda z: besselk(-2.25, z))
|
|
||||||
test_asymp(airyai)
|
|
||||||
test_asymp(airybi)
|
|
||||||
test_asymp(lambda z: hyp0f1(a1, z))
|
|
||||||
test_asymp(lambda z: hyp1f1(a1, a2, z))
|
|
||||||
test_asymp(lambda z: hyp1f2(a1, a2, a3, z))
|
|
||||||
test_asymp(lambda z: hyp2f0(a1, a2, z))
|
|
||||||
test_asymp(lambda z: hyperu(a1, a2, z))
|
|
||||||
test_asymp(lambda z: hyp2f1(a1, a2, a3, z))
|
|
||||||
test_asymp(lambda z: hyp2f2(a1, a2, a3, a4, z))
|
|
||||||
test_asymp(lambda z: hyp2f3(a1, a2, a3, a4, a5, z))
|
|
||||||
test_asymp(lambda z: coulombf(a1, a2, z))
|
|
||||||
test_asymp(lambda z: coulombg(a1, a2, z))
|
|
||||||
test_asymp(lambda z: polylog(2,z))
|
|
||||||
test_asymp(lambda z: polylog(3,z))
|
|
||||||
test_asymp(lambda z: polylog(-2,z))
|
|
||||||
test_asymp(lambda z: expint(4, z))
|
|
||||||
test_asymp(lambda z: expint(-4, z))
|
|
||||||
test_asymp(lambda z: expint(2.25, z))
|
|
||||||
test_asymp(lambda z: gammainc(2.5, z, 5))
|
|
||||||
test_asymp(lambda z: gammainc(2.5, 5, z))
|
|
||||||
test_asymp(lambda z: hermite(3, z))
|
|
||||||
test_asymp(lambda z: hermite(2.5, z))
|
|
||||||
test_asymp(lambda z: legendre(3, z))
|
|
||||||
test_asymp(lambda z: legendre(4, z))
|
|
||||||
test_asymp(lambda z: legendre(2.5, z))
|
|
||||||
test_asymp(lambda z: legenp(a1, a2, z))
|
|
||||||
test_asymp(lambda z: legenq(a1, a2, z), maxdps=90) # abnormally slow
|
|
||||||
test_asymp(lambda z: jtheta(1, z, 0.5))
|
|
||||||
test_asymp(lambda z: jtheta(2, z, 0.5))
|
|
||||||
test_asymp(lambda z: jtheta(3, z, 0.5))
|
|
||||||
test_asymp(lambda z: jtheta(4, z, 0.5))
|
|
||||||
test_asymp(lambda z: jtheta(1, z, 0.5, 1))
|
|
||||||
test_asymp(lambda z: jtheta(2, z, 0.5, 1))
|
|
||||||
test_asymp(lambda z: jtheta(3, z, 0.5, 1))
|
|
||||||
test_asymp(lambda z: jtheta(4, z, 0.5, 1))
|
|
||||||
test_asymp(barnesg, maxdps=90)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def testit(line):
|
|
||||||
if filt in line:
|
|
||||||
print line
|
|
||||||
t1 = clock()
|
|
||||||
exec line
|
|
||||||
t2 = clock()
|
|
||||||
elapsed = t2-t1
|
|
||||||
print "Time:", elapsed, "for", line, "(OK)"
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
try:
|
|
||||||
from multiprocessing import Pool
|
|
||||||
mapf = Pool(None).map
|
|
||||||
print "Running tests with multiprocessing"
|
|
||||||
except ImportError:
|
|
||||||
print "Not using multiprocessing"
|
|
||||||
mapf = map
|
|
||||||
t1 = clock()
|
|
||||||
tasks = cases.splitlines()
|
|
||||||
mapf(testit, tasks)
|
|
||||||
t2 = clock()
|
|
||||||
print "Cumulative wall time:", t2-t1
|
|
||||||
|
|
||||||
|
|
@ -1,91 +0,0 @@
|
||||||
|
|
||||||
def monitor(f, input='print', output='print'):
|
|
||||||
"""
|
|
||||||
Returns a wrapped copy of *f* that monitors evaluation by calling
|
|
||||||
*input* with every input (*args*, *kwargs*) passed to *f* and
|
|
||||||
*output* with every value returned from *f*. The default action
|
|
||||||
(specify using the special string value ``'print'``) is to print
|
|
||||||
inputs and outputs to stdout, along with the total evaluation
|
|
||||||
count::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> mp.dps = 5; mp.pretty = False
|
|
||||||
>>> diff(monitor(exp), 1) # diff will eval f(x-h) and f(x+h)
|
|
||||||
in 0 (mpf('0.99999999906867742538452148'),) {}
|
|
||||||
out 0 mpf('2.7182818259274480055282064')
|
|
||||||
in 1 (mpf('1.0000000009313225746154785'),) {}
|
|
||||||
out 1 mpf('2.7182818309906424675501024')
|
|
||||||
mpf('2.7182808')
|
|
||||||
|
|
||||||
To disable either the input or the output handler, you may
|
|
||||||
pass *None* as argument.
|
|
||||||
|
|
||||||
Custom input and output handlers may be used e.g. to store
|
|
||||||
results for later analysis::
|
|
||||||
|
|
||||||
>>> mp.dps = 15
|
|
||||||
>>> input = []
|
|
||||||
>>> output = []
|
|
||||||
>>> findroot(monitor(sin, input.append, output.append), 3.0)
|
|
||||||
mpf('3.1415926535897932')
|
|
||||||
>>> len(input) # Count number of evaluations
|
|
||||||
9
|
|
||||||
>>> print input[3], output[3]
|
|
||||||
((mpf('3.1415076583334066'),), {}) 8.49952562843408e-5
|
|
||||||
>>> print input[4], output[4]
|
|
||||||
((mpf('3.1415928201669122'),), {}) -1.66577118985331e-7
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not input:
|
|
||||||
input = lambda v: None
|
|
||||||
elif input == 'print':
|
|
||||||
incount = [0]
|
|
||||||
def input(value):
|
|
||||||
args, kwargs = value
|
|
||||||
print "in %s %r %r" % (incount[0], args, kwargs)
|
|
||||||
incount[0] += 1
|
|
||||||
if not output:
|
|
||||||
output = lambda v: None
|
|
||||||
elif output == 'print':
|
|
||||||
outcount = [0]
|
|
||||||
def output(value):
|
|
||||||
print "out %s %r" % (outcount[0], value)
|
|
||||||
outcount[0] += 1
|
|
||||||
def f_monitored(*args, **kwargs):
|
|
||||||
input((args, kwargs))
|
|
||||||
v = f(*args, **kwargs)
|
|
||||||
output(v)
|
|
||||||
return v
|
|
||||||
return f_monitored
|
|
||||||
|
|
||||||
def timing(f, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Returns time elapsed for evaluating ``f()``. Optionally arguments
|
|
||||||
may be passed to time the execution of ``f(*args, **kwargs)``.
|
|
||||||
|
|
||||||
If the first call is very quick, ``f`` is called
|
|
||||||
repeatedly and the best time is returned.
|
|
||||||
"""
|
|
||||||
once = kwargs.get('once')
|
|
||||||
if 'once' in kwargs:
|
|
||||||
del kwargs['once']
|
|
||||||
if args or kwargs:
|
|
||||||
if len(args) == 1 and not kwargs:
|
|
||||||
arg = args[0]
|
|
||||||
g = lambda: f(arg)
|
|
||||||
else:
|
|
||||||
g = lambda: f(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
g = f
|
|
||||||
from timeit import default_timer as clock
|
|
||||||
t1=clock(); v=g(); t2=clock(); t=t2-t1
|
|
||||||
if t > 0.05 or once:
|
|
||||||
return t
|
|
||||||
for i in range(3):
|
|
||||||
t1=clock();
|
|
||||||
# Evaluate multiple times because the timer function
|
|
||||||
# has a significant overhead
|
|
||||||
g();g();g();g();g();g();g();g();g();g()
|
|
||||||
t2=clock()
|
|
||||||
t=min(t,(t2-t1)/10)
|
|
||||||
return t
|
|
||||||
|
|
@ -1,270 +0,0 @@
|
||||||
"""
|
|
||||||
Plotting (requires matplotlib)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from colorsys import hsv_to_rgb, hls_to_rgb
|
|
||||||
|
|
||||||
class VisualizationMethods(object):
|
|
||||||
plot_ignore = (ValueError, ArithmeticError, ZeroDivisionError)
|
|
||||||
|
|
||||||
def plot(ctx, f, xlim=[-5,5], ylim=None, points=200, file=None, dpi=None,
|
|
||||||
singularities=[], axes=None):
|
|
||||||
r"""
|
|
||||||
Shows a simple 2D plot of a function `f(x)` or list of functions
|
|
||||||
`[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
|
|
||||||
specified by *xlim*. Some examples::
|
|
||||||
|
|
||||||
plot(lambda x: exp(x)*li(x), [1, 4])
|
|
||||||
plot([cos, sin], [-4, 4])
|
|
||||||
plot([fresnels, fresnelc], [-4, 4])
|
|
||||||
plot([sqrt, cbrt], [-4, 4])
|
|
||||||
plot(lambda t: zeta(0.5+t*j), [-20, 20])
|
|
||||||
plot([floor, ceil, abs, sign], [-5, 5])
|
|
||||||
|
|
||||||
Points where the function raises a numerical exception or
|
|
||||||
returns an infinite value are removed from the graph.
|
|
||||||
Singularities can also be excluded explicitly
|
|
||||||
as follows (useful for removing erroneous vertical lines)::
|
|
||||||
|
|
||||||
plot(cot, ylim=[-5, 5]) # bad
|
|
||||||
plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi]) # good
|
|
||||||
|
|
||||||
For parts where the function assumes complex values, the
|
|
||||||
real part is plotted with dashes and the imaginary part
|
|
||||||
is plotted with dots.
|
|
||||||
|
|
||||||
NOTE: This function requires matplotlib (pylab).
|
|
||||||
"""
|
|
||||||
if file:
|
|
||||||
axes = None
|
|
||||||
fig = None
|
|
||||||
if not axes:
|
|
||||||
import pylab
|
|
||||||
fig = pylab.figure()
|
|
||||||
axes = fig.add_subplot(111)
|
|
||||||
if not isinstance(f, (tuple, list)):
|
|
||||||
f = [f]
|
|
||||||
a, b = xlim
|
|
||||||
colors = ['b', 'r', 'g', 'm', 'k']
|
|
||||||
for n, func in enumerate(f):
|
|
||||||
x = ctx.arange(a, b, (b-a)/float(points))
|
|
||||||
segments = []
|
|
||||||
segment = []
|
|
||||||
in_complex = False
|
|
||||||
for i in xrange(len(x)):
|
|
||||||
try:
|
|
||||||
if i != 0:
|
|
||||||
for sing in singularities:
|
|
||||||
if x[i-1] <= sing and x[i] >= sing:
|
|
||||||
raise ValueError
|
|
||||||
v = func(x[i])
|
|
||||||
if ctx.isnan(v) or abs(v) > 1e300:
|
|
||||||
raise ValueError
|
|
||||||
if hasattr(v, "imag") and v.imag:
|
|
||||||
re = float(v.real)
|
|
||||||
im = float(v.imag)
|
|
||||||
if not in_complex:
|
|
||||||
in_complex = True
|
|
||||||
segments.append(segment)
|
|
||||||
segment = []
|
|
||||||
segment.append((float(x[i]), re, im))
|
|
||||||
else:
|
|
||||||
if in_complex:
|
|
||||||
in_complex = False
|
|
||||||
segments.append(segment)
|
|
||||||
segment = []
|
|
||||||
segment.append((float(x[i]), v))
|
|
||||||
except ctx.plot_ignore:
|
|
||||||
if segment:
|
|
||||||
segments.append(segment)
|
|
||||||
segment = []
|
|
||||||
if segment:
|
|
||||||
segments.append(segment)
|
|
||||||
for segment in segments:
|
|
||||||
x = [s[0] for s in segment]
|
|
||||||
y = [s[1] for s in segment]
|
|
||||||
if not x:
|
|
||||||
continue
|
|
||||||
c = colors[n % len(colors)]
|
|
||||||
if len(segment[0]) == 3:
|
|
||||||
z = [s[2] for s in segment]
|
|
||||||
axes.plot(x, y, '--'+c, linewidth=3)
|
|
||||||
axes.plot(x, z, ':'+c, linewidth=3)
|
|
||||||
else:
|
|
||||||
axes.plot(x, y, c, linewidth=3)
|
|
||||||
axes.set_xlim(xlim)
|
|
||||||
if ylim:
|
|
||||||
axes.set_ylim(ylim)
|
|
||||||
axes.set_xlabel('x')
|
|
||||||
axes.set_ylabel('f(x)')
|
|
||||||
axes.grid(True)
|
|
||||||
if fig:
|
|
||||||
if file:
|
|
||||||
pylab.savefig(file, dpi=dpi)
|
|
||||||
else:
|
|
||||||
pylab.show()
|
|
||||||
|
|
||||||
def default_color_function(ctx, z):
|
|
||||||
if ctx.isinf(z):
|
|
||||||
return (1.0, 1.0, 1.0)
|
|
||||||
if ctx.isnan(z):
|
|
||||||
return (0.5, 0.5, 0.5)
|
|
||||||
pi = 3.1415926535898
|
|
||||||
a = (float(ctx.arg(z)) + ctx.pi) / (2*ctx.pi)
|
|
||||||
a = (a + 0.5) % 1.0
|
|
||||||
b = 1.0 - float(1/(1.0+abs(z)**0.3))
|
|
||||||
return hls_to_rgb(a, b, 0.8)
|
|
||||||
|
|
||||||
def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
|
|
||||||
verbose=False, file=None, dpi=None, axes=None):
|
|
||||||
"""
|
|
||||||
Plots the given complex-valued function *f* over a rectangular part
|
|
||||||
of the complex plane specified by the pairs of intervals *re* and *im*.
|
|
||||||
For example::
|
|
||||||
|
|
||||||
cplot(lambda z: z, [-2, 2], [-10, 10])
|
|
||||||
cplot(exp)
|
|
||||||
cplot(zeta, [0, 1], [0, 50])
|
|
||||||
|
|
||||||
By default, the complex argument (phase) is shown as color (hue) and
|
|
||||||
the magnitude is show as brightness. You can also supply a
|
|
||||||
custom color function (*color*). This function should take a
|
|
||||||
complex number as input and return an RGB 3-tuple containing
|
|
||||||
floats in the range 0.0-1.0.
|
|
||||||
|
|
||||||
To obtain a sharp image, the number of points may need to be
|
|
||||||
increased to 100,000 or thereabout. Since evaluating the
|
|
||||||
function that many times is likely to be slow, the 'verbose'
|
|
||||||
option is useful to display progress.
|
|
||||||
|
|
||||||
NOTE: This function requires matplotlib (pylab).
|
|
||||||
"""
|
|
||||||
if color is None:
|
|
||||||
color = ctx.default_color_function
|
|
||||||
import pylab
|
|
||||||
if file:
|
|
||||||
axes = None
|
|
||||||
fig = None
|
|
||||||
if not axes:
|
|
||||||
fig = pylab.figure()
|
|
||||||
axes = fig.add_subplot(111)
|
|
||||||
rea, reb = re
|
|
||||||
ima, imb = im
|
|
||||||
dre = reb - rea
|
|
||||||
dim = imb - ima
|
|
||||||
M = int(ctx.sqrt(points*dre/dim)+1)
|
|
||||||
N = int(ctx.sqrt(points*dim/dre)+1)
|
|
||||||
x = pylab.linspace(rea, reb, M)
|
|
||||||
y = pylab.linspace(ima, imb, N)
|
|
||||||
# Note: we have to be careful to get the right rotation.
|
|
||||||
# Test with these plots:
|
|
||||||
# cplot(lambda z: z if z.real < 0 else 0)
|
|
||||||
# cplot(lambda z: z if z.imag < 0 else 0)
|
|
||||||
w = pylab.zeros((N, M, 3))
|
|
||||||
for n in xrange(N):
|
|
||||||
for m in xrange(M):
|
|
||||||
z = ctx.mpc(x[m], y[n])
|
|
||||||
try:
|
|
||||||
v = color(f(z))
|
|
||||||
except ctx.plot_ignore:
|
|
||||||
v = (0.5, 0.5, 0.5)
|
|
||||||
w[n,m] = v
|
|
||||||
if verbose:
|
|
||||||
print n, "of", N
|
|
||||||
axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
|
|
||||||
axes.set_xlabel('Re(z)')
|
|
||||||
axes.set_ylabel('Im(z)')
|
|
||||||
if fig:
|
|
||||||
if file:
|
|
||||||
pylab.savefig(file, dpi=dpi)
|
|
||||||
else:
|
|
||||||
pylab.show()
|
|
||||||
|
|
||||||
def splot(ctx, f, u=[-5,5], v=[-5,5], points=100, keep_aspect=True, \
|
|
||||||
wireframe=False, file=None, dpi=None, axes=None):
|
|
||||||
"""
|
|
||||||
Plots the surface defined by `f`.
|
|
||||||
|
|
||||||
If `f` returns a single component, then this plots the surface
|
|
||||||
defined by `z = f(x,y)` over the rectangular domain with
|
|
||||||
`x = u` and `y = v`.
|
|
||||||
|
|
||||||
If `f` returns three components, then this plots the parametric
|
|
||||||
surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.
|
|
||||||
|
|
||||||
For example, to plot a simple function::
|
|
||||||
|
|
||||||
>>> from mpmath import *
|
|
||||||
>>> f = lambda x, y: sin(x+y)*cos(y)
|
|
||||||
>>> splot(f, [-pi,pi], [-pi,pi]) # doctest: +SKIP
|
|
||||||
|
|
||||||
Plotting a donut::
|
|
||||||
|
|
||||||
>>> r, R = 1, 2.5
|
|
||||||
>>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
|
|
||||||
>>> splot(f, [0, 2*pi], [0, 2*pi]) # doctest: +SKIP
|
|
||||||
|
|
||||||
NOTE: This function requires matplotlib (pylab) 0.98.5.3 or higher.
|
|
||||||
"""
|
|
||||||
import pylab
|
|
||||||
import mpl_toolkits.mplot3d as mplot3d
|
|
||||||
if file:
|
|
||||||
axes = None
|
|
||||||
fig = None
|
|
||||||
if not axes:
|
|
||||||
fig = pylab.figure()
|
|
||||||
axes = mplot3d.axes3d.Axes3D(fig)
|
|
||||||
ua, ub = u
|
|
||||||
va, vb = v
|
|
||||||
du = ub - ua
|
|
||||||
dv = vb - va
|
|
||||||
if not isinstance(points, (list, tuple)):
|
|
||||||
points = [points, points]
|
|
||||||
M, N = points
|
|
||||||
u = pylab.linspace(ua, ub, M)
|
|
||||||
v = pylab.linspace(va, vb, N)
|
|
||||||
x, y, z = [pylab.zeros((M, N)) for i in xrange(3)]
|
|
||||||
xab, yab, zab = [[0, 0] for i in xrange(3)]
|
|
||||||
for n in xrange(N):
|
|
||||||
for m in xrange(M):
|
|
||||||
fdata = f(ctx.convert(u[m]), ctx.convert(v[n]))
|
|
||||||
try:
|
|
||||||
x[m,n], y[m,n], z[m,n] = fdata
|
|
||||||
except TypeError:
|
|
||||||
x[m,n], y[m,n], z[m,n] = u[m], v[n], fdata
|
|
||||||
for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
|
|
||||||
if c < cab[0]:
|
|
||||||
cab[0] = c
|
|
||||||
if c > cab[1]:
|
|
||||||
cab[1] = c
|
|
||||||
if wireframe:
|
|
||||||
axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
|
|
||||||
else:
|
|
||||||
axes.plot_surface(x, y, z, rstride=4, cstride=4)
|
|
||||||
axes.set_xlabel('x')
|
|
||||||
axes.set_ylabel('y')
|
|
||||||
axes.set_zlabel('z')
|
|
||||||
if keep_aspect:
|
|
||||||
dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
|
|
||||||
maxd = max(dx, dy, dz)
|
|
||||||
if dx < maxd:
|
|
||||||
delta = maxd - dx
|
|
||||||
axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
|
|
||||||
if dy < maxd:
|
|
||||||
delta = maxd - dy
|
|
||||||
axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
|
|
||||||
if dz < maxd:
|
|
||||||
delta = maxd - dz
|
|
||||||
axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
|
|
||||||
if fig:
|
|
||||||
if file:
|
|
||||||
pylab.savefig(file, dpi=dpi)
|
|
||||||
else:
|
|
||||||
pylab.show()
|
|
||||||
|
|
||||||
|
|
||||||
VisualizationMethods.plot = plot
|
|
||||||
VisualizationMethods.default_color_function = default_color_function
|
|
||||||
VisualizationMethods.cplot = cplot
|
|
||||||
VisualizationMethods.splot = splot
|
|
||||||
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
# -*- coding: ISO-8859-1 -*-
|
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# Copyright (C) 2002-2005 Jörg Lehmann <joergl@users.sourceforge.net>
|
# Copyright (C) 2002-2005 Jorg Lehmann <joergl@users.sourceforge.net>
|
||||||
# Copyright (C) 2002-2006 André Wobst <wobsta@users.sourceforge.net>
|
# Copyright (C) 2002-2006 Andre Wobst <wobsta@users.sourceforge.net>
|
||||||
#
|
#
|
||||||
# This file is part of PyX (http://pyx.sourceforge.net/).
|
# This file is part of PyX (http://pyx.sourceforge.net/).
|
||||||
#
|
#
|
||||||
|
|
@ -28,8 +27,8 @@ interface. Complex tasks like 2d and 3d plots in publication-ready quality are
|
||||||
built out of these primitives.
|
built out of these primitives.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import version
|
from .version import version
|
||||||
__version__ = version.version
|
__version__ = version
|
||||||
|
|
||||||
__all__ = ["attr", "box", "bitmap", "canvas", "color", "connector", "deco", "deformer", "document",
|
__all__ = ["attr", "box", "bitmap", "canvas", "color", "connector", "deco", "deformer", "document",
|
||||||
"epsfile", "graph", "mesh", "path", "pattern", "style", "trafo", "text", "unit"]
|
"epsfile", "graph", "mesh", "path", "pattern", "style", "trafo", "text", "unit"]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
#!/usr/bin/env python2.7
|
#!/usr/bin/env python
|
||||||
"""
|
"""
|
||||||
This script will generate a stimulus file for a given period, load, and slew input
|
This script will generate a stimulus file for a given period, load, and slew input
|
||||||
for the given dimension SRAM. It is useful for debugging after an SRAM has been
|
for the given dimension SRAM. It is useful for debugging after an SRAM has been
|
||||||
|
|
|
||||||
|
|
@ -89,11 +89,11 @@ def print_banner():
|
||||||
def check_versions():
|
def check_versions():
|
||||||
""" Run some checks of required software versions. """
|
""" Run some checks of required software versions. """
|
||||||
|
|
||||||
# check that we are not using version 3 and at least 2.7
|
# Now require python >=3.6
|
||||||
major_python_version = sys.version_info.major
|
major_python_version = sys.version_info.major
|
||||||
minor_python_version = sys.version_info.minor
|
minor_python_version = sys.version_info.minor
|
||||||
if not (major_python_version == 2 and minor_python_version >= 7):
|
if not (major_python_version == 3 and minor_python_version >= 6):
|
||||||
debug.error("Python 2.7 is required.",-1)
|
debug.error("Python 3.6 or greater is required.",-1)
|
||||||
|
|
||||||
# FIXME: Check versions of other tools here??
|
# FIXME: Check versions of other tools here??
|
||||||
# or, this could be done in each module (e.g. verify, characterizer, etc.)
|
# or, this could be done in each module (e.g. verify, characterizer, etc.)
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class bank(design.design):
|
||||||
"bitcell_array", "sense_amp_array", "precharge_array",
|
"bitcell_array", "sense_amp_array", "precharge_array",
|
||||||
"column_mux_array", "write_driver_array", "tri_gate_array",
|
"column_mux_array", "write_driver_array", "tri_gate_array",
|
||||||
"bank_select"]
|
"bank_select"]
|
||||||
|
from importlib import reload
|
||||||
for mod_name in mod_list:
|
for mod_name in mod_list:
|
||||||
config_mod_name = getattr(OPTS, mod_name)
|
config_mod_name = getattr(OPTS, mod_name)
|
||||||
class_file = reload(__import__(config_mod_name))
|
class_file = reload(__import__(config_mod_name))
|
||||||
|
|
@ -130,8 +131,8 @@ class bank(design.design):
|
||||||
def compute_sizes(self):
|
def compute_sizes(self):
|
||||||
""" Computes the required sizes to create the bank """
|
""" Computes the required sizes to create the bank """
|
||||||
|
|
||||||
self.num_cols = self.words_per_row*self.word_size
|
self.num_cols = int(self.words_per_row*self.word_size)
|
||||||
self.num_rows = self.num_words / self.words_per_row
|
self.num_rows = int(self.num_words / self.words_per_row)
|
||||||
|
|
||||||
self.row_addr_size = int(log(self.num_rows, 2))
|
self.row_addr_size = int(log(self.num_rows, 2))
|
||||||
self.col_addr_size = int(log(self.words_per_row, 2))
|
self.col_addr_size = int(log(self.words_per_row, 2))
|
||||||
|
|
@ -320,7 +321,7 @@ class bank(design.design):
|
||||||
y_offset = self.sense_amp_array.height+self.column_mux_height \
|
y_offset = self.sense_amp_array.height+self.column_mux_height \
|
||||||
+ self.write_driver_array.height + self.m2_gap + self.tri_gate_array.height
|
+ self.write_driver_array.height + self.m2_gap + self.tri_gate_array.height
|
||||||
self.tri_gate_array_inst=self.add_inst(name="tri_gate_array",
|
self.tri_gate_array_inst=self.add_inst(name="tri_gate_array",
|
||||||
mod=self.tri_gate_array,
|
mod=self.tri_gate_array,
|
||||||
offset=vector(0,y_offset).scale(-1,-1))
|
offset=vector(0,y_offset).scale(-1,-1))
|
||||||
|
|
||||||
temp = []
|
temp = []
|
||||||
|
|
@ -852,9 +853,7 @@ class bank(design.design):
|
||||||
|
|
||||||
def analytical_delay(self, slew, load):
|
def analytical_delay(self, slew, load):
|
||||||
""" return analytical delay of the bank"""
|
""" return analytical delay of the bank"""
|
||||||
msf_addr_delay = self.msf_address.analytical_delay(slew, self.row_decoder.input_load())
|
decoder_delay = self.row_decoder.analytical_delay(slew, self.wordline_driver.input_load())
|
||||||
|
|
||||||
decoder_delay = self.row_decoder.analytical_delay(msf_addr_delay.slew, self.wordline_driver.input_load())
|
|
||||||
|
|
||||||
word_driver_delay = self.wordline_driver.analytical_delay(decoder_delay.slew, self.bitcell_array.input_load())
|
word_driver_delay = self.wordline_driver.analytical_delay(decoder_delay.slew, self.bitcell_array.input_load())
|
||||||
|
|
||||||
|
|
@ -866,7 +865,6 @@ class bank(design.design):
|
||||||
|
|
||||||
data_t_DATA_delay = self.tri_gate_array.analytical_delay(bl_t_data_out_delay.slew, load)
|
data_t_DATA_delay = self.tri_gate_array.analytical_delay(bl_t_data_out_delay.slew, load)
|
||||||
|
|
||||||
result = msf_addr_delay + decoder_delay + word_driver_delay \
|
result = decoder_delay + word_driver_delay + bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
|
||||||
+ bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ class control_logic(design.design):
|
||||||
self.inv8 = pinv(size=16, height=dff_height)
|
self.inv8 = pinv(size=16, height=dff_height)
|
||||||
self.add_mod(self.inv8)
|
self.add_mod(self.inv8)
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.replica_bitline))
|
c = reload(__import__(OPTS.replica_bitline))
|
||||||
replica_bitline = getattr(c, OPTS.replica_bitline)
|
replica_bitline = getattr(c, OPTS.replica_bitline)
|
||||||
# FIXME: These should be tuned according to the size!
|
# FIXME: These should be tuned according to the size!
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ class delay_chain(design.design):
|
||||||
self.num_inverters = 1 + sum(fanout_list)
|
self.num_inverters = 1 + sum(fanout_list)
|
||||||
self.num_top_half = round(self.num_inverters / 2.0)
|
self.num_top_half = round(self.num_inverters / 2.0)
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.bitcell))
|
c = reload(__import__(OPTS.bitcell))
|
||||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||||
self.bitcell = self.mod_bitcell()
|
self.bitcell = self.mod_bitcell()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class dff_array(design.design):
|
||||||
design.design.__init__(self, name)
|
design.design.__init__(self, name)
|
||||||
debug.info(1, "Creating {}".format(self.name))
|
debug.info(1, "Creating {}".format(self.name))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.dff))
|
c = reload(__import__(OPTS.dff))
|
||||||
self.mod_dff = getattr(c, OPTS.dff)
|
self.mod_dff = getattr(c, OPTS.dff)
|
||||||
self.dff = self.mod_dff("dff")
|
self.dff = self.mod_dff("dff")
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class dff_buf(design.design):
|
||||||
design.design.__init__(self, name)
|
design.design.__init__(self, name)
|
||||||
debug.info(1, "Creating {}".format(self.name))
|
debug.info(1, "Creating {}".format(self.name))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.dff))
|
c = reload(__import__(OPTS.dff))
|
||||||
self.mod_dff = getattr(c, OPTS.dff)
|
self.mod_dff = getattr(c, OPTS.dff)
|
||||||
self.dff = self.mod_dff("dff")
|
self.dff = self.mod_dff("dff")
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ class dff_inv(design.design):
|
||||||
design.design.__init__(self, name)
|
design.design.__init__(self, name)
|
||||||
debug.info(1, "Creating {}".format(self.name))
|
debug.info(1, "Creating {}".format(self.name))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.dff))
|
c = reload(__import__(OPTS.dff))
|
||||||
self.mod_dff = getattr(c, OPTS.dff)
|
self.mod_dff = getattr(c, OPTS.dff)
|
||||||
self.dff = self.mod_dff("dff")
|
self.dff = self.mod_dff("dff")
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ class hierarchical_decoder(design.design):
|
||||||
def __init__(self, rows):
|
def __init__(self, rows):
|
||||||
design.design.__init__(self, "hierarchical_decoder_{0}rows".format(rows))
|
design.design.__init__(self, "hierarchical_decoder_{0}rows".format(rows))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.bitcell))
|
c = reload(__import__(OPTS.bitcell))
|
||||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||||
self.bitcell_height = self.mod_bitcell.height
|
self.bitcell_height = self.mod_bitcell.height
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ class hierarchical_predecode(design.design):
|
||||||
self.number_of_outputs = int(math.pow(2, self.number_of_inputs))
|
self.number_of_outputs = int(math.pow(2, self.number_of_inputs))
|
||||||
design.design.__init__(self, name="pre{0}x{1}".format(self.number_of_inputs,self.number_of_outputs))
|
design.design.__init__(self, name="pre{0}x{1}".format(self.number_of_inputs,self.number_of_outputs))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.bitcell))
|
c = reload(__import__(OPTS.bitcell))
|
||||||
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
self.mod_bitcell = getattr(c, OPTS.bitcell)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class ms_flop_array(design.design):
|
||||||
design.design.__init__(self, name)
|
design.design.__init__(self, name)
|
||||||
debug.info(1, "Creating {}".format(self.name))
|
debug.info(1, "Creating {}".format(self.name))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.ms_flop))
|
c = reload(__import__(OPTS.ms_flop))
|
||||||
self.mod_ms_flop = getattr(c, OPTS.ms_flop)
|
self.mod_ms_flop = getattr(c, OPTS.ms_flop)
|
||||||
self.ms = self.mod_ms_flop("ms_flop")
|
self.ms = self.mod_ms_flop("ms_flop")
|
||||||
|
|
@ -27,7 +28,7 @@ class ms_flop_array(design.design):
|
||||||
|
|
||||||
self.width = self.columns * self.ms.width
|
self.width = self.columns * self.ms.width
|
||||||
self.height = self.ms.height
|
self.height = self.ms.height
|
||||||
self.words_per_row = self.columns / self.word_size
|
self.words_per_row = int(self.columns / self.word_size)
|
||||||
|
|
||||||
self.create_layout()
|
self.create_layout()
|
||||||
|
|
||||||
|
|
@ -57,13 +58,16 @@ class ms_flop_array(design.design):
|
||||||
else:
|
else:
|
||||||
base = vector((i+1)*self.ms.width,0)
|
base = vector((i+1)*self.ms.width,0)
|
||||||
mirror = "MY"
|
mirror = "MY"
|
||||||
self.ms_inst[i/self.words_per_row]=self.add_inst(name=name,
|
|
||||||
|
index = int(i/self.words_per_row)
|
||||||
|
|
||||||
|
self.ms_inst[index]=self.add_inst(name=name,
|
||||||
mod=self.ms,
|
mod=self.ms,
|
||||||
offset=base,
|
offset=base,
|
||||||
mirror=mirror)
|
mirror=mirror)
|
||||||
self.connect_inst(["din[{0}]".format(i/self.words_per_row),
|
self.connect_inst(["din[{0}]".format(index),
|
||||||
"dout[{0}]".format(i/self.words_per_row),
|
"dout[{0}]".format(index),
|
||||||
"dout_bar[{0}]".format(i/self.words_per_row),
|
"dout_bar[{0}]".format(index),
|
||||||
"clk",
|
"clk",
|
||||||
"vdd", "gnd"])
|
"vdd", "gnd"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class replica_bitline(design.design):
|
||||||
def __init__(self, delay_stages, delay_fanout, bitcell_loads, name="replica_bitline"):
|
def __init__(self, delay_stages, delay_fanout, bitcell_loads, name="replica_bitline"):
|
||||||
design.design.__init__(self, name)
|
design.design.__init__(self, name)
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
g = reload(__import__(OPTS.delay_chain))
|
g = reload(__import__(OPTS.delay_chain))
|
||||||
self.mod_delay_chain = getattr(g, OPTS.delay_chain)
|
self.mod_delay_chain = getattr(g, OPTS.delay_chain)
|
||||||
|
|
||||||
|
|
@ -132,11 +133,10 @@ class replica_bitline(design.design):
|
||||||
""" Connect all the signals together """
|
""" Connect all the signals together """
|
||||||
self.route_vdd()
|
self.route_vdd()
|
||||||
self.route_gnd()
|
self.route_gnd()
|
||||||
|
self.route_vdd_gnd()
|
||||||
self.route_access_tx()
|
self.route_access_tx()
|
||||||
|
|
||||||
def route_vdd_gnd(self):
|
def route_vdd_gnd(self):
|
||||||
""" Route all the vdd and gnd pins to the top level """
|
|
||||||
def route_vdd_gnd(self):
|
|
||||||
""" Propagate all vdd/gnd pins up to this level for all modules """
|
""" Propagate all vdd/gnd pins up to this level for all modules """
|
||||||
|
|
||||||
# These are the instances that every bank has
|
# These are the instances that every bank has
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ class sense_amp_array(design.design):
|
||||||
design.design.__init__(self, "sense_amp_array")
|
design.design.__init__(self, "sense_amp_array")
|
||||||
debug.info(1, "Creating {0}".format(self.name))
|
debug.info(1, "Creating {0}".format(self.name))
|
||||||
|
|
||||||
|
from importlib import reload
|
||||||
c = reload(__import__(OPTS.sense_amp))
|
c = reload(__import__(OPTS.sense_amp))
|
||||||
self.mod_sense_amp = getattr(c, OPTS.sense_amp)
|
self.mod_sense_amp = getattr(c, OPTS.sense_amp)
|
||||||
self.amp = self.mod_sense_amp("sense_amp")
|
self.amp = self.mod_sense_amp("sense_amp")
|
||||||
|
|
@ -33,7 +34,8 @@ class sense_amp_array(design.design):
|
||||||
def add_pins(self):
|
def add_pins(self):
|
||||||
|
|
||||||
for i in range(0,self.row_size,self.words_per_row):
|
for i in range(0,self.row_size,self.words_per_row):
|
||||||
self.add_pin("data[{0}]".format(i/self.words_per_row))
|
index = int(i/self.words_per_row)
|
||||||
|
self.add_pin("data[{0}]".format(index))
|
||||||
self.add_pin("bl[{0}]".format(i))
|
self.add_pin("bl[{0}]".format(i))
|
||||||
self.add_pin("br[{0}]".format(i))
|
self.add_pin("br[{0}]".format(i))
|
||||||
|
|
||||||
|
|
@ -62,12 +64,14 @@ class sense_amp_array(design.design):
|
||||||
br_offset = amp_position + br_pin.ll().scale(1,0)
|
br_offset = amp_position + br_pin.ll().scale(1,0)
|
||||||
dout_offset = amp_position + dout_pin.ll()
|
dout_offset = amp_position + dout_pin.ll()
|
||||||
|
|
||||||
|
index = int(i/self.words_per_row)
|
||||||
|
|
||||||
inst = self.add_inst(name=name,
|
inst = self.add_inst(name=name,
|
||||||
mod=self.amp,
|
mod=self.amp,
|
||||||
offset=amp_position)
|
offset=amp_position)
|
||||||
self.connect_inst(["bl[{0}]".format(i),
|
self.connect_inst(["bl[{0}]".format(i),
|
||||||
"br[{0}]".format(i),
|
"br[{0}]".format(i),
|
||||||
"data[{0}]".format(i/self.words_per_row),
|
"data[{0}]".format(index),
|
||||||
"en", "vdd", "gnd"])
|
"en", "vdd", "gnd"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,19 +89,18 @@ class sense_amp_array(design.design):
|
||||||
layer="metal3",
|
layer="metal3",
|
||||||
offset=vdd_pos)
|
offset=vdd_pos)
|
||||||
|
|
||||||
|
self.add_layout_pin(text="bl[{0}]".format(i),
|
||||||
self.add_layout_pin(text="bl[{0}]".format(i/self.words_per_row),
|
|
||||||
layer="metal2",
|
layer="metal2",
|
||||||
offset=bl_offset,
|
offset=bl_offset,
|
||||||
width=bl_pin.width(),
|
width=bl_pin.width(),
|
||||||
height=bl_pin.height())
|
height=bl_pin.height())
|
||||||
self.add_layout_pin(text="br[{0}]".format(i/self.words_per_row),
|
self.add_layout_pin(text="br[{0}]".format(i),
|
||||||
layer="metal2",
|
layer="metal2",
|
||||||
offset=br_offset,
|
offset=br_offset,
|
||||||
width=br_pin.width(),
|
width=br_pin.width(),
|
||||||
height=br_pin.height())
|
height=br_pin.height())
|
||||||
|
|
||||||
self.add_layout_pin(text="data[{0}]".format(i/self.words_per_row),
|
self.add_layout_pin(text="data[{0}]".format(index),
|
||||||
layer="metal2",
|
layer="metal2",
|
||||||
offset=dout_offset,
|
offset=dout_offset,
|
||||||
width=dout_pin.width(),
|
width=dout_pin.width(),
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ class single_level_column_mux_array(design.design):
|
||||||
debug.info(1, "Creating {0}".format(self.name))
|
debug.info(1, "Creating {0}".format(self.name))
|
||||||
self.columns = columns
|
self.columns = columns
|
||||||
self.word_size = word_size
|
self.word_size = word_size
|
||||||
self.words_per_row = self.columns / self.word_size
|
self.words_per_row = int(self.columns / self.word_size)
|
||||||
self.add_pins()
|
self.add_pins()
|
||||||
self.create_layout()
|
self.create_layout()
|
||||||
self.DRC_LVS()
|
self.DRC_LVS()
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue