在創(chuàng)建一個符合Python風(fēng)格的對象(1)中伪很,定義了一個二維向量 Vector2d
類,現(xiàn)在以該類為基礎(chǔ)杂数,繼續(xù)擴展宛畦,定義表示多維向量的Vector
類。
支持的功能如下:
- 基本的序列協(xié)議揍移,
__len__
和__getitem__
- 正確表述擁有很多元素的實例
- 適當(dāng)?shù)那衅С执魏停糜谏a(chǎn)新的
Vector
實例 - 綜合各個元素的值計算散列值
- 自定義的格式語言擴展
此外,通過 __getattr__
方法實現(xiàn)屬性的動態(tài)存取那伐,以此取代 Vector2d
使用的只讀特性——不過踏施,序列類型通常不會這么做石蔗。
下面來一步步實現(xiàn)。
1.為了支持N維向量畅形,讓構(gòu)造函數(shù)接受可迭代對象
def __init__(self, components):
# 把 Vector 的分量保存在一個數(shù)組中
self._components = array(self.typecode, components)
2.為了支持迭代养距,使用self.components
構(gòu)建一個迭代器
def __iter__(self):
return iter(self._components)
3.使用reprlib.repr()
函數(shù)獲取 self._components
的有限長度表示形式(如 array('d', [0.0, 1.0, 2.0, 3.0, 4.0, ...])
def __repr__(self):
components = reprlib.repr(self._components)
# 去掉前面的 array('d' 和后面的 )。
components = components[components.find('['):-1]
return 'Vector({})'.format(components)
4.直接使用self.components
構(gòu)建bytes
對象
def __bytes__(self):
return (bytes(ord([self.typecode])) + bytes(self._components))
5計算模
def __abs__(self):
"""計算各分量的平方之和日熬,然后再使用 sqrt 方法開平方"""
return math.sqrt(sum(x * x for x in self))
6.針對frombytes
棍厌,直接把 memoryview
傳給構(gòu)造方法,不用像前面那樣使用 *
拆包
@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)
7.為了支持序列協(xié)議竖席,實現(xiàn)__len__
和__getitem__
方法
def __len__(self):
return len(self._components)
def __getitem__(self, index):
"""自定義切片操作"""
cls = type(self)
# 如果 index 參數(shù)的值是 slice 對象,調(diào)用類的構(gòu)造方法毕荐,使用 _components 數(shù)組的切片構(gòu)建一個新 Vector 實例
if isinstance(index, slice):
return cls(self._components[index])
# 如果 index 是 int 或其他整數(shù)類型,那就返回 _components 中相應(yīng)的元素
elif isinstance(index, numbers.Integral):
return self._components[index]
# 否則憎亚,拋出異常
else:
msg = '{.__name__} indices must be integers'
raise TypeError(msg.format(cls))
8.動態(tài)存取屬性
因為現(xiàn)在是N維向量第美,使用Vector2d
中獲取屬性的方式顯然太麻煩。
要想依舊使用my_obj.x
方式獲取屬性牲览,可以實現(xiàn)__getattr__
方法第献,因為屬性查找失敗后庸毫,解釋器會調(diào)用 __getattr__
方法飒赃。
# 定義幾個可以獲取的常用分量
shortcut__names = 'xyzt'
def __getattr__(self, name):
"""檢查所查找的屬性是不是 shortcut__names 中的某個字母载佳,如果是臀栈,那么返回對應(yīng)的分量。"""
cls = type(self)
# 如果屬性名只有一個字母姑躲,可能是shortcut_names 中的一個
if len(name) == 1:
# 找到所在位置
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(self._components):
return self._components[pos]
msg = '{.__name__ !r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))
但是僅僅實現(xiàn)這樣一個方法還不夠,需要注意到對于實例v
卖怜,如果執(zhí)行了v.x
命令阐枣,實際上v
對象就有x
屬性了侮繁,因此使用v.x
不會調(diào)用__getattr__
方法宪哩。
為了避免上述情況锁孟,需要改寫Vector類中設(shè)置屬性的邏輯品抽,通過自定義__setattr__
方法實現(xiàn)圆恤。
def __setattr__(self, name, value):
cls = type(self)
# 特別處理名稱是單個字符的屬性
if len(name) == 1:
# 如果 name 是 shortcut_names 中的一個盆昙,設(shè)置特殊的錯誤消息
if name in cls.shortcut_names:
error = 'readonly attribute {attr_name!r}'
# 如果 name 是小寫字母淡喜,為所有小寫字母設(shè)置一個錯誤消息
elif name.islower():
error = "can't set attributes 'a' to 'z' in {cls_name!r}"
# 否則炼团,把錯誤消息設(shè)為空字符串
else:
error = ''
# 如果有錯誤消息,拋出 AttributeError
if error:
msg = error.format(cls_name=cls.__name__, attr_name=name)
raise AttributeError(msg)
# 默認(rèn)情況:在超類上調(diào)用 __setattr__ 方法疏尿,提供標(biāo)準(zhǔn)行為
super().__setattr__(name, value)
在類中聲明 __slots__
屬性也可以防止設(shè)置新實例屬性。但是不建議只為了避免創(chuàng)建實例屬性而使用 __slots__
屬性褥琐。__slots__
屬性只應(yīng)該用于節(jié)省內(nèi)存,而且僅當(dāng)內(nèi)存嚴(yán)重不足時才應(yīng)該這么做踩衩。
另外嚼鹉,為了將該類實例變成是可散列的,需要保持Vector
是不可變的锚赤。
9.支持散列和快速等值測試
def __eq__(self, other):
# 首先要檢查兩個操作數(shù)的長度是否相同匹舞,因為 zip 函數(shù)會在最短的那個操作數(shù)耗盡時停止,而且不發(fā)出警告线脚。
# 然后再依次比較兩個序列中的每一個元素
return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
def __hash__(self):
# 創(chuàng)建一個生成器表達(dá)式,惰性計算各個分量的散列值
hashes = (hash(x) for x in self)
# 把 hashes 提供給 reduce 函數(shù)姊舵,使用 xor 函數(shù)計算聚合的散列值寓落;第三個參數(shù),0 是初始值
return functools.reduce(operator.xor, hashes, 0)
10.格式化
Vector
類支持 N 個維度仰税,所以這里使用球面坐標(biāo),格式后綴定義為'h'
。這里的難點主要是涉及數(shù)學(xué)原理河绽,理解意思即可拦赠。具體可以查看n 維球體
def angle(self, n):
"""使用公式計算某個角坐標(biāo)"""
r = math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a
def angles(self):
"""創(chuàng)建生成器表達(dá)式,按需計算所有角坐標(biāo)"""
return (self.angle(n) for n in range(1, len(self)))
def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'): # 超球面坐標(biāo)
fmt_spec = fmt_spec[:-1]
# 使用 itertools.chain 函數(shù)生成生成器表達(dá)式葵姥,無縫迭代向量的模和各個角坐標(biāo)
coords = itertools.chain([abs(self)], self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
下面給出完整代碼
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
class Vector:
typecode = 'd'
def __init__(self, components):
self._components = array(self.typecode, components)
def __iter__(self):
return iter(self._components)
def __repr__(self):
components = reprlib.repr(self._components)
components = components[components.find('['):-1]
return 'Vector({})'.format(components)
def __str__(self):
return str(tuple(self))
def __bytes__(self):
return (bytes([ord(self.typecode)]) +
bytes(self._components))
def __eq__(self, other):
return (len(self) == len(other) and
all(a == b for a, b in zip(self, other)))
def __hash__(self):
hashes = (hash(x) for x in self)
return functools.reduce(operator.xor, hashes, 0)
def __abs__(self):
return math.sqrt(sum(x * x for x in self))
def __bool__(self):
return bool(abs(self))
def __len__(self):
return len(self._components)
def __getitem__(self, index):
cls = type(self)
if isinstance(index, slice):
return cls(self._components[index])
elif isinstance(index, numbers.Integral):
return self._components[index]
else:
msg = '{.__name__} indices must be integers'
raise TypeError(msg.format(cls))
shortcut_names = 'xyzt'
def __getattr__(self, name):
cls = type(self)
if len(name) == 1:
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(self._components):
return self._components[pos]
msg = '{.__name__!r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))
def angle(self, n):
r = math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a
def angles(self):
return (self.angle(n) for n in range(1, len(self)))
def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'): # 超球面坐標(biāo)
fmt_spec = fmt_spec[:-1]
coords = itertools.chain([abs(self)],
self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)