1. 程式人生 > 其它 >Go語言學習: 實現輕量執行緒池

Go語言學習: 實現輕量執行緒池

Goroutine 池的核心思想是對 Goroutine 的重用,也就是把 M 個計算任務排程到 N 個 Goroutine 上,而不是為每個計算任務分配一個獨享的 Goroutine,從而提高計算資源的利用率。

這裡簡單地採用channel+select的實現方案,主要可分成3部分:

  1. pool的建立與銷燬

  2. pool中的worker(Goroutine)的管理

  3. task的提交與排程

定義一個結構體Pool,它應當具有一些屬性:

capacity 是 pool 的一個屬性,代表整個 pool 中 worker 的最大容量。使用一個帶緩衝的 channel:active,作為 worker 的“計數器”,這種 channel 使用模式就是計數訊號量,那麼其對應的資料型別就是struct{}。

當 active channel 可寫時,就建立一個 worker,用於處理使用者通過 Schedule 函式提交的待處理的請求。當 active channel 滿了的時候,pool 就會停止 worker 的建立,直到某個 worker 因故退出,active channel 又空出一個位置時,pool 才會建立新的 worker 填補那個空位。

把使用者要提交給 workerpool 執行的請求抽象為一個 Task。Task 的提交與排程也很簡單:Task 通過 Schedule 函式提交到一個 task channel 中,已經建立的 worker 將從這個 task channel 中讀取 task 並執行。

定義瞭如下結構體:

type Pool struct {
    capacity int         // workerpool大小

    active chan struct{} // 對應上圖中的active channel
    tasks  chan Task     // 對應上圖中的task channel

    wg   sync.WaitGroup  // 用於在pool銷燬時等待所有worker退出
    quit chan struct{}   // 用於通知各個worker退出的訊號channel
}

下面實現一個New函式,用於建立一個pool型別例項,並將pool池的worker管理機制執行起來。

func New(capacity int) *Pool {
    if capacity <= 0 {
        capacity = defaultCapacity
    }
    if capacity > maxCapacity {
        capacity = maxCapacity
    }

    p := &Pool{
        capacity: capacity,
        tasks:    make(chan Task),
        quit:     make(chan struct{}),
        active:   make(chan struct{}, capacity),
    }
    fmt.Printf("workpool start\n")
    go p.run()
    return p
}

New函式接收一個引數capacity用於指定workerpool池的容量,這個引數用於控制wokerpool最多隻能有capacity個worker,共同處理使用者提交的任務請求。函式開始處會檢查傳參是否合理。

Pool型別例項變數p完成初始化後,建立一個新的Goroutine,用於workerpool進行管理,這個Goroutine用於對workerpool進行管理,這個goroutine執行的是pool型別的run方法。

func (p *Pool) run() {
    index := 0
    for {
        select {
        case <-p.quit:
            return
        case p.active <- struct{}{}:
            index++
            p.newWorker(index)
        }
    }
}

run方法內是一個無限迴圈,迴圈體中使用select監視pool型別例項的兩個channel:quit和active。這種在for迴圈中使用select監視多個channel的實現,在Go程式碼中十分常見。

當接收到來自quit channel的退出"訊號"時,這個Goroutine就會結束執行。而當active channel可寫時,run方法就會建立一個新的worker Goroutine。此外,為了方便在程式中區分各個worker輸出的日誌,這裡將一個從1開始的變數index作為worker的編號,並將它以引數形式傳給建立worker的方法。

將建立新的 worker goroutine 的職責,封裝到一個名為 newWorker 的方法中:

func (p *Pool) newWorker(i int) {
    p.wg.Add(1)
    go func() {
        defer func() {
            if err := recover(); err != nil {
                fmt.Printf("worker[%03d]: recover panic[%s] and exit\n", i, err)
                <-p.active
            }
            p.wg.Done()
        }()

        fmt.Printf("worker[%03d]: start\n", i)

        for {
            select {
            case <-p.quit:
                fmt.Printf("worker[%03d]: exit\n", i)
                <-p.active
                return
            case t := <-p.tasks:
                fmt.Printf("worker[%03d]: receive a task\n", i)
                t()
            }
        }
    }()
}

在建立一個新的 worker goroutine 之前,newWorker 方法會先呼叫 p.wg.Add 方法將 WaitGroup 的等待計數加一。由於每個 worker 運行於一個獨立的 Goroutine 中,newWorker 方法通過 go 關鍵字建立了一個新的 Goroutine 作為 worker。

新 worker 的核心,依然是一個基於 for-select 模式的迴圈語句,在迴圈體中,新 worker 通過 select 監視 quit 和 tasks 兩個 channel。和前面的 run 方法一樣,當接收到來自 quit channel 的退出“訊號”時,這個 worker 就會結束執行。tasks channel 中放置的是使用者通過 Schedule 方法提交的請求,新 worker 會從這個 channel 中獲取最新的 Task 並執行這個 Task。

Task 是一個對使用者提交的請求的抽象,它的本質就是一個函式型別:

type Task func()

在新 worker 中,為了防止使用者提交的 task 丟擲 panic,進而導致整個 workerpool 受到影響,在 worker 程式碼的開始處,使用了 defer+recover 對 panic 進行捕捉,捕捉後 worker 也是要退出的,於是還通過<-p.active更新了 worker 計數器。並且一旦 worker goroutine 退出,p.wg.Done 也需要被呼叫,這樣可以減少 WaitGroup 的 Goroutine 等待數量。

workerpool 提供給使用者提交請求的匯出方法 Schedule:

var ErrWorkerPoolFreed    = errors.New("workerpool freed")       // workerpool已終止執行

func (p *Pool) Schedule(t Task) error {
    select {
    case <-p.quit:
        return ErrWorkerPoolFreed
    case p.tasks <- t:
        return nil
    }
}

這裡要注意的是,這裡的 Pool 結構體中的 tasks 是一個無緩衝的 channel,如果 pool 中 worker 數量已達上限,而且 worker 都在處理 task 的狀態,那麼 Schedule 方法就會阻塞,直到有 worker 變為 idle 狀態來讀取 tasks channel,schedule 的呼叫阻塞才會解除。

完整程式碼:

package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

type Task func()

const (
	maxCapacity     = 20
	defaultCapacity = 10
)

type Pool struct {
	capacity int
	wg       sync.WaitGroup
	active   chan struct{}
	quit     chan struct{}
	tasks    chan Task
}

func New(capacity int) *Pool {
	if capacity <= 0 {
		capacity = defaultCapacity
	}
	if capacity > maxCapacity {
		capacity = maxCapacity
	}
	p := &Pool{
		capacity: capacity,
		active:   make(chan struct{}, capacity),
		quit:     make(chan struct{}),
		tasks:    make(chan Task),
	}
	go p.run()
	return p
}

func (p *Pool) run() {
	index := 0
	for {
		select {
		case <-p.quit:
			return
		case p.active <- struct{}{}:
			index++
			p.newWorker(index)
		}
	}
}

func (p *Pool) newWorker(index int) {
	p.wg.Add(1)
	go func() {
		defer func() {
			if err := recover(); err != nil {
				fmt.Printf("worker[%d] panic[%s] and exit\n", index, err)
				<-p.active
			}
			p.wg.Done()
		}()

		fmt.Printf("worker[%03d] start\n", index)
		for {
			select {
			case <-p.quit:
				fmt.Printf("worker[%03d] exit\n", index)
				<-p.active
				return
			case t := <-p.tasks:
				fmt.Printf("worker[%03d] received a task\n", index)
				t()
			}
		}
	}()
}

var ErrWorkerPoolFreed = errors.New("workerpool freed")

func (p *Pool) Schedule(t Task) error {
	select {
	case <-p.quit:
		return ErrWorkerPoolFreed
	case p.tasks <- t:
		return nil
	}
}

func (p *Pool) Free() {
	close(p.quit)
	p.wg.Wait()
	fmt.Println("workpool freed")
}

func main() {
	p := New(5)
	for i := 0; i < 10; i++ {
		err := p.Schedule(func() {
			time.Sleep(time.Second * 3)
		})
		if err != nil {
			println("task: ", i, "err:", err)
		}
	}
	p.Free()
}