python 生成器和fire
阿新 • • 發佈:2020-08-03
class HyperParameters(object):
"""
用於管理模型超引數
"""
def __init__(
self,
max_length: int = 128,
epochs=4,
batch_size=32,
learning_rate=2e-5,
fp16=True,
fp16_opt_level="O1",
max_grad_norm=1.0,
warmup_steps=0.1,
) -> None:
self.max_length = max_length
self.epochs = epochs
"""訓練迭代輪數"""
self.batch_size = batch_size
"""每個batch的樣本數量"""
self.learning_rate = learning_rate
"""學習率"""
self.fp16 = fp16
"""是否使用fp16混合精度訓練"""
self.fp16_opt_level = fp16_opt_level
"""用於fp16,Apex AMP優化等級,['O0', 'O1', 'O2', and 'O3']可選,詳見https://nvidia.github.io/apex/amp.html"""
"""最大梯度裁剪"""
self.warmup_steps = warmup_steps
"""學習率線性預熱步數"""
-------引數
def __repr__(self) -> str:
return self.__dict__.__repr__()
HyperParameters()
-----------------生成器
def gen(sth):
for _ in ["a", "b", "c"]:
yield _
lst = gen("")
next(lst)
------------------fire
import fire
def hello_word(name, time):
print(time, "hello", name)
if __name__ == '__main__':
fire.Fire(hello_word)
python XX.py --name s --time now