原文地址: https://www.zhouwenzhen.top/archives/48/
使用Python生成LaTeX 數(shù)學(xué)公式
在閱讀算法文獻或者數(shù)學(xué)相關(guān)的文章中經(jīng)常會看到一些簡單或復(fù)雜的數(shù)學(xué)公式,最近在分享此類文章時,想使用LaTex鍵入數(shù)學(xué)公式以美化閱讀,發(fā)現(xiàn)需要反復(fù)去查詢LaTex相關(guān)的語法嗜傅,效率較低且容易出錯千埃。
最近 GitHub 上出現(xiàn)了一個開源項目 latexify_py春叫,它使用 Python 就能生成 LaTeX 數(shù)學(xué)公式亭螟。打開Google Colaboratory示例列舉了幾個案例:
先試試看
在本地安裝相應(yīng)的Python包残家,Python版本 >= 3.6
pip install latexify-py
參考官方示例進行測試:
import math
import latexify
@latexify.with_latex
def solve(a, b, c):
return (-b + math.sqrt(b ** 2 - 4 * a * c)) / (2 * a)
if __name__ == '__main__':
print(solve)
終端打印結(jié)果為:
\mathrm{solve}(a, b, c)\triangleq \frac{-b + \sqrt{b^{2} - 4ac}}{2a}
將打印結(jié)果輸入到支持LaTeX的編輯器中榆俺,以Typora為例。選擇插入公式塊:
于是,把最近閱讀的facebook開源的prophet時間序列預(yù)測算法提到的飽和增長模型公式進行測試茴晋,原文中為
開始在python中鍵入代碼:
@latexify.with_latex
def g(t):
return C(t) / (1 + exp(1-(k + alpha(t) ** T * delta) * (t -(m + alpha(t) ** T * gamma))))
終端打印結(jié)果并輸入Typora為:
\mathrm{g}(t)\triangleq \frac{\mathrm{C}\left(t\right)}{1 + \mathrm{exp}\left(1 - (k + \mathrm{{\alpha}}\left(t\right)^{t}{\delta})(t - m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma})\right)}
對比發(fā)現(xiàn)python輸出的公式中有一個錯誤:刪除了一個括號陪捷,而python代碼中是包含的,由
變成了:
為了進一步驗證上面出現(xiàn)的問題诺擅,輸入一段很簡單的代碼:
@latexify.with_latex
def test(a, b):
return - (a + b)
輸出的公式和預(yù)想的一致:
這時市袖,小小的修改一下代碼:
@latexify.with_latex
def test(a, b):
return 1 - (a + b)
預(yù)想的公式應(yīng)該為:
而實際卻是:
猜想,這可能是一個bug或者是輸入的方式不對烁涌,雖然這個問題很好解決苍碟,但是一直很疑惑。撮执。微峰。。抒钱。
latexify_py做了什么蜓肆?
為了一探究竟,嘗試去閱讀其源碼谋币,看看它都做了哪些事情仗扬?
首先入口是@latexify.with_latex這個注解。latexify提供with_latex和get_latex兩個注解蕾额,with_latex只是先做一些初始化早芭,實際也是調(diào)用get_latex。重點看一下get_latex凡简,其源碼:
def get_latex(fn, math_symbol=True):
try:
source = inspect.getsource(fn)##獲取整個模塊的源代碼
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)
return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source)) ##ast.parse把源碼解析為AST節(jié)點逼友,AST是抽象語法樹,不依賴于具體的文法秤涩,不依賴于語言的細節(jié)帜乞,我們將源代碼轉(zhuǎn)化為AST后,可以對AST做很多的操作
LatexifyVisitor繼承ast的NodeVisitor筐眷,ast.NodeVisitor是一個專門用來遍歷語法樹的工具黎烈,可以通過繼承這個類來完成對語法樹的遍歷以及遍歷過程中的處理。
LatexifyVisitor首先從根節(jié)點root進行遍歷匀谣,在遍歷的過程中照棋,每個節(jié)點類型都有專用的類型處理函數(shù),以"visit_" + "Node類型"為名稱武翎,如果不存在烈炭,則調(diào)用通用的的處理函數(shù)generic_visit。
在latexify的core.py直接引入astunparse宝恶,將生成的ast打印出來:
def get_latex(fn, math_symbol=True):
try:
source = inspect.getsource(fn)
print(astunparse.dump(ast.parse(source)))
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)
return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source))
下面是test對應(yīng)的ast結(jié)構(gòu):
Module(
body=[FunctionDef(
name='test',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='a',
annotation=None,
type_comment=None),
arg(
arg='b',
annotation=None,
type_comment=None)],
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[]),
body=[Return(value=BinOp(
left=Constant(
value=1,
kind=None),
op=Sub(),
right=BinOp(
left=Name(
id='a',
ctx=Load()),
op=Add(),
right=Name(
id='b',
ctx=Load()))))],
decorator_list=[Attribute(
value=Name(
id='latexify',
ctx=Load()),
attr='with_latex',
ctx=Load())],
returns=None,
type_comment=None)],
type_ignores=[])
首先訪問根節(jié)點root符隙,root為Moudle類型趴捅,會調(diào)用visit_Moudle函數(shù),以此始遍歷子節(jié)點FunctionDef霹疫、Return和BinOp拱绑,調(diào)用對應(yīng)的visit_FunctionDef、visit_Return和vist_BinOp丽蝎。
參照打印出來的python公式代碼和ast結(jié)構(gòu)猎拨,來分析一下整體邏輯:
vist_FunctionDef
def visit_FunctionDef(self, node):
name_str = r'\mathrm{' + str(node.name) + '}'
arg_strs = [self._parse_math_symbols(str(arg.arg)) for arg in node.args.args]
body_str = self.visit(node.body[0])
return name_str + '(' + ', '.join(arg_strs) + r')\triangleq ' + body_str
遍歷FunctionDef節(jié)點后,輸出為:
\mathrm{test}(a屠阻,b)\triangleq
visit_Return
def visit_Return(self, node):
return self.visit(node.value)
Return節(jié)點的值為子節(jié)點红省,類型為BinOp。ast將輸入的代碼分為left和right国觉,test例子中类腮,left為常數(shù)1,right是下一個子節(jié)點蛉加,類型為BinOp蚜枢,op為運算符,這里為Sub減法针饥〕С椋看看visit_BinOp:
visit_BinOp
def visit_BinOp(self, node):
priority = {
ast.Add: 10,
ast.Sub: 10,
ast.Mult: 20,
ast.MatMult: 20,
ast.Div: 20,
ast.FloorDiv: 20,
ast.Mod: 20,
ast.Pow: 30,
}
def _unwrap(child):
return self.visit(child)
def _wrap(child):
latex = _unwrap(child)
if isinstance(child, ast.BinOp):
cp = priority[type(child.op)] if type(child.op) in priority else 100
pp = priority[type(node.op)] if type(node.op) in priority else 100
if cp < pp:
return '(' + latex + ')'
return latex
l = node.left
r = node.right
reprs = {
ast.Add: (lambda: _wrap(l) + ' + ' + _wrap(r)),
ast.Sub: (lambda: _wrap(l) + ' - ' + _wrap(r)),
ast.Mult: (lambda: _wrap(l) + _wrap(r)),
ast.MatMult: (lambda: _wrap(l) + _wrap(r)),
ast.Div: (lambda: r'\frac{' + _unwrap(l) + '}{' + _unwrap(r) + '}'),
ast.FloorDiv: (lambda: r'\left\lfloor\frac{' + _unwrap(l) + '}{' + _unwrap(r) + r'}\right\rfloor'),
ast.Mod: (lambda: _wrap(l) + r' \bmod ' + _wrap(r)),
ast.Pow: (lambda: _wrap(l) + '^{' + _unwrap(r) + '}'),
}
if type(node.op) in reprs:
return reprs[type(node.op)]()
else:
return r'\mathrm{unknown\_binop}(' + _unwrap(l) + ', ' + _unwrap(r) + ')'
ast.Add和ast.Sub設(shè)置的優(yōu)先級都為10,_wrap方法通過優(yōu)先級來判斷是否添加括號丁眼,即:
cp = priority[type(child.op)] if type(child.op) in priority else 100
pp = priority[type(node.op)] if type(node.op) in priority else 100
if cp < pp:
return '(' + latex + ')'
test例子中child.op為Sub筷凤,node.op是right中的op為Add,優(yōu)先級相同不添加括號苞七,所以輸出:
1 - a + b
遍歷結(jié)束后輸出:
\mathrm{test}(a, b)\triangleq 1 - a + b
這和公式實際上表達的意思南轅北轍藐守,解決方法就是將小于改為小于等于,即
if cp <= pp:
return '(' + latex + ')'