1. 程式人生 > >Golang 實現 Redis(6): 實現 pipeline 模式的 redis 客戶端

Golang 實現 Redis(6): 實現 pipeline 模式的 redis 客戶端

本文是使用 golang 實現 redis 系列的第六篇, 將介紹如何實現一個 Pipeline 模式的 Redis 客戶端。 本文的完整程式碼在[Github:Godis/redis/client](https://github.com/HDT3213/godis/blob/master/src/redis/client/client.go) 通常 TCP 客戶端的通訊模式都是阻塞式的: 客戶端傳送請求 -> 等待服務端響應 -> 傳送下一個請求。因為需要等待網路傳輸資料,完成一次請求迴圈需要等待較多時間。 我們能否不等待服務端響應直接傳送下一條請求呢?答案是肯定的。 TCP 作為全雙工協議可以同時進行上行和下行通訊,不必擔心客戶端和服務端同時發包會導致衝突。 > p.s. 打電話的時候兩個人同時講話就會衝突聽不清,只能輪流講。這種通訊方式稱為半雙工。廣播只能由電臺傳送到收音機不能反向傳輸,這種方式稱為單工。 我們為每一個 tcp 連線分配了一個 goroutine 可以保證先收到的請求先先回復。另一個方面,tcp 協議會保證資料流的有序性,同一個 tcp 連線上先發送的請求服務端先接收,先回復的響應客戶端先收到。因此我們不必擔心混淆響應所對應的請求。 這種在服務端未響應時客戶端繼續向服務端傳送請求的模式稱為 Pipeline 模式。因為減少等待網路傳輸的時間,Pipeline 模式可以極大的提高吞吐量,減少所需使用的 tcp 連結數。 pipeline 模式的 redis 客戶端需要有兩個後臺協程程負責 tcp 通訊,呼叫方通過 channel 向後臺協程傳送指令,並阻塞等待直到收到響應,這是一個典型的非同步程式設計模式。 我們先來定義 client 的結構: ```golang type Client struct { conn net.Conn // 與服務端的 tcp 連線 sendingReqs chan *Request // 等待發送的請求 waitingReqs chan *Request // 等待伺服器響應的請求 ticker *time.Ticker // 用於觸發心跳包的計時器 addr string ctx context.Context cancelFunc context.CancelFunc writing *sync.WaitGroup // 有請求正在處理不能立即停止,用於實現 graceful shutdown } type Request struct { id uint64 // 請求id args [][]byte // 上行引數 reply redis.Reply // 收到的返回值 heartbeat bool // 標記是否是心跳請求 waiting *wait.Wait // 呼叫協程傳送請求後通過 waitgroup 等待請求非同步處理完成 err error } ``` 呼叫者將請求傳送給後臺協程,並通過 wait group 等待非同步處理完成: ```golang func (client *Client) Send(args [][]byte) redis.Reply { request := &Request{ args: args, heartbeat: false, waiting: &wait.Wait{}, } request.waiting.Add(1) client.sendingReqs <- request // 將請求發往處理佇列 timeout := request.waiting.WaitWithTimeout(maxWait) // 等待請求處理完成或者超時 if timeout { return reply.MakeErrReply("server time out") } if request.err != nil { return reply.MakeErrReply("request failed: " + err.Error()) } return request.reply } ``` client 的核心部分是後臺的讀寫協程。先從寫協程開始: ```golang // 寫協程入口 func (client *Client) handleWrite() { loop: for { select { case req := <-client.sendingReqs: // 從 channel 中取出請求 client.writing.Add(1) // 未完成請求數+1 client.doRequest(req) // 傳送請求 case <-client.ctx.Done(): break loop } } } // 傳送請求 func (client *Client) doRequest(req *Request) { bytes := reply.MakeMultiBulkReply(req.args).ToBytes() // 序列化 _, err := client.conn.Write(bytes) // 通過 tcp connection 傳送 i := 0 for err != nil && i < 3 { // 失敗重試 err = client.handleConnectionError(err) if err == nil { _, err = client.conn.Write(bytes) } i++ } if err == nil { client.waitingReqs <- req // 將傳送成功的請求放入等待響應的佇列 } else { // 傳送失敗 req.err = err req.waiting.Done() // 結束呼叫者的等待 client.writing.Done() // 未完成請求數 -1 } } ``` 讀協程是我們熟悉的協議解析器模板, 不熟悉的朋友可以到[實現 Redis 協議解析器](https://www.cnblogs.com/Finley/p/11923168.html)瞭解更多。 ```golang // 收到服務端的響應 func (client *Client) finishRequest(reply redis.Reply) { request := <-client.waitingReqs // 取出等待響應的 request request.reply = reply if request.waiting != nil { request.waiting.Done() // 結束呼叫者的等待 } client.writing.Done() // 未完成請求數-1 } // 讀協程是個 RESP 協議解析器,不熟悉的朋友可以 func (client *Client) handleRead() error { reader := bufio.NewReader(client.conn) downloading := false expectedArgsCount := 0 receivedCount := 0 msgType := byte(0) // first char of msg var args [][]byte var fixedLen int64 = 0 var err error var msg []byte for { // read line if fixedLen == 0 { // read normal line msg, err = reader.ReadBytes('\n') if err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { logger.Info("connection close") } else { logger.Warn(err) } return errors.New("connection closed") } if len(msg) == 0 || msg[len(msg)-2] != '\r' { return errors.New("protocol error") } } else { // read bulk line (binary safe) msg = make([]byte, fixedLen+2) _, err = io.ReadFull(reader, msg) if err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { return errors.New("connection closed") } else { return err } } if len(msg) == 0 || msg[len(msg)-2] != '\r' || msg[len(msg)-1] != '\n' { return errors.New("protocol error") } fixedLen = 0 } // parse line if !downloading { // receive new response if msg[0] == '*' { // multi bulk response // bulk multi msg expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) if err != nil { return errors.New("protocol error: " + err.Error()) } if expectedLine == 0 { client.finishRequest(&reply.EmptyMultiBulkReply{}) } else if expectedLine >
0 { msgType = msg[0] downloading = true expectedArgsCount = int(expectedLine) receivedCount = 0 args = make([][]byte, expectedLine) } else { return errors.New("protocol error") } } else if msg[0] == '$' { // bulk response fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) if err != nil { return err } if fixedLen == -1 { // null bulk client.finishRequest(&reply.NullBulkReply{}) fixedLen = 0 } else if fixedLen >
0 { msgType = msg[0] downloading = true expectedArgsCount = 1 receivedCount = 0 args = make([][]byte, 1) } else { return errors.New("protocol error") } } else { // single line response str := strings.TrimSuffix(string(msg), "\n") str = strings.TrimSuffix(str, "\r") var result redis.Reply switch msg[0] { case '+': result = reply.MakeStatusReply(str[1:]) case '-': result = reply.MakeErrReply(str[1:]) case ':': val, err := strconv.ParseInt(str[1:], 10, 64) if err != nil { return errors.New("protocol error") } result = reply.MakeIntReply(val) } client.finishRequest(result) } } else { // receive following part of a request line := msg[0 : len(msg)-2] if line[0] == '$' { fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return err } if fixedLen <= 0 { // null bulk in multi bulks args[receivedCount] = []byte{} receivedCount++ fixedLen = 0 } } else { args[receivedCount] = line receivedCount++ } // if sending finished if receivedCount == expectedArgsCount { downloading = false // finish downloading progress if msgType == '*' { reply := reply.MakeMultiBulkReply(args) client.finishRequest(reply) } else if msgType == '$' { reply := reply.MakeBulkReply(args[0]) client.finishRequest(reply) } // finish reply expectedArgsCount = 0 receivedCount = 0 args = nil msgType = byte(0) } } } } ``` 最後編寫 client 的構造器和啟動非同步協程的程式碼: ```golang func MakeClient(addr string) (*Client, error) { conn, err := net.Dial("tcp", addr) if err != nil { return nil, err } ctx, cancel := context.WithCancel(context.Background()) return &Client{ addr: addr, conn: conn, sendingReqs: make(chan *Request, chanSize), waitingReqs: make(chan *Request, chanSize), ctx: ctx, cancelFunc: cancel, writing: &sync.WaitGroup{}, }, nil } func (client *Client) Start() { client.ticker = time.NewTicker(10 * time.Second) go client.handleWrite() go func() { err := client.handleRead() logger.Warn(err) }() go client.heartbeat() } ``` 關閉 client 的時候記得等待請求完成: ```golang func (client *Client) Close() { // 先阻止新請求進入佇列 close(client.sendingReqs) // 等待處理中的請求完成 client.writing.Wait() // 釋放資源 _ = client.conn.Close() // 關閉與服務端的連線,連線關閉後讀協程會退出 client.cancelFunc() // 使用 context 關閉讀協程 close(client.waitingReqs) // 關閉佇列 } ``` 測試一下: ```golang func TestClient(t *testing.T) { client, err := MakeClient("localhost:6379") if err != nil { t.Error(err) } client.Start() result = client.Send([][]byte{ []byte("SET"), []byte("a"), []byte("a"), }) if statusRet, ok := result.(*reply.StatusReply); ok { if statusRet.Status != "OK" { t.Error("`set` failed, result: " + statusRet.Status) } } result = client.Send([][]byte{ []byte("GET"), []byte("a"), }) if bulkRet, ok := result.(*reply.BulkReply); ok { if string(bulkRet.Arg) != "a" { t.Error("`get` failed, result: " + string(bulkRet.Arg)) } } } ```