1. 程式人生 > >Golang實現Unix Socket

Golang實現Unix Socket

我們都知道socket本來是為網路通訊設計,可以通過socket方便的實現不同機器之間的通訊。當然通過socket也可以實現同一臺主機之間的程序通訊。但是通過socket網路協議實現程序通訊效率太低,後來就出現了IPC通訊協議,UNIX Domain Socket (UDS)就是其中之一,而且用於IPC更有效率,比如:資料不再需要經過網路協議棧,也不需要進行打包拆包、計算校驗和、維護序列號和及實現應答機制等,要做的只是將應用層資料從一個程序拷貝到另一個程序。

類似於Socket的TCP和UDP,UNIX Domain Socket也提供面向流和麵向資料包兩種實現,但是面向訊息的UNIX Domain Socket也是可靠的,訊息既不會丟失也不會順序錯亂。UNIX Domain Socket是全雙工的,在實現的時候可以明顯的發現客戶端與服務端通訊是通過檔案讀寫進行的通訊,而且可以同時讀寫。相比其它IPC機制有明顯的優越性,目前已成為使用最廣泛的IPC機制,如果用過supervisor的同學會發現在配置supervisor的時候需要設定一個supervisor.sock的問題地址,不難猜出,supervisor即使用unix socket來進行通訊的。

使用UNIX Domain Socket的過程和網路socket十分相似,也要先呼叫socket()建立一個socket檔案描述符,address family指定為AF_UNIX,type可以選擇SOCK_DGRAM或SOCK_STREAM,protocol引數仍然指定為0即可。但是我們通過Golang或者其他高階語言實現的時候,會掩蓋很多底層資訊。

UNIX Domain Socket與網路socket程式設計最明顯的不同在於地址格式不同,用結構體sockaddr_un表示,網路程式設計的socket地址是IP地址加埠號,而UNIX Domain Socket的地址是一個socket型別的檔案在檔案系統中的路徑,這個socket檔案由bind()呼叫建立,如果呼叫bind()時該檔案已存在,則bind()錯誤返回。所以在實現的時候,可以再啟動的時候刪掉sock檔案,也可以在程式的signal捕獲到退出時刪除sock檔案。

我們定義Client 程式碼和Server端程式碼,方便測試。分別build,然後先啟動Server,再啟動Client即可。

Server端程式碼如下:

package main

import (
	"encoding/binary"
	"bytes"
	"io"
	"os"
	"fmt"
	"net"
	"time"
)

const (
	UNIX_SOCK_PIPE_PATH = "/var/run/unixsock_test.sock" // socket file path
)

func main() {
	// Remove socket file
	os.Remove(UNIX_SOCK_PIPE_PATH)
	// Get unix socket address based on file path
	uaddr, err := net.ResolveUnixAddr("unix", UNIX_SOCK_PIPE_PATH)
	if err != nil {
		fmt.Println(err)
		return
	}

	// Listen on the address
	unixListener, err := net.ListenUnix("unix", uaddr)
	if err != nil {
		fmt.Println(err)
		return
	}

	// Close listener when close this function, you can also emit it because this function will not terminate gracefully
	defer unixListener.Close()

	fmt.Println("Waiting for asking questions ...")
	// Monitor request and process
	for {
		uconn, err := unixListener.AcceptUnix()
		if err != nil {
			fmt.Println(err)
			continue
		}

		// Handle request
		go handleConnection(uconn)
	}
}

/*******************************************************
* Handle connection and request
* conn: conn handler
*******************************************************/
func handleConnection(conn *net.UnixConn) {
	// Close connection when finish handling
	defer func() {
		conn.Close()
	}()

	// Read data and return response
	data, err := parseRequest(conn)
	if err != nil {
		fmt.Println(err)
		return
	}

	fmt.Printf("%+v\tReceived from client: %s\n", time.Now(), string(data))
	time.Sleep(time.Duration(1) * time.Second) // sleep to simulate request process

	// Send back response
	sendResponse(conn, []byte(time.Now().String()))
}

/*******************************************************
* Parse request of unix socket
* conn: conn handler
*******************************************************/
func parseRequest(conn *net.UnixConn) ([]byte, error) {
	var reqLen uint32
	lenBytes := make([]byte, 4)
	if _, err := io.ReadFull(conn, lenBytes); err != nil {
		return nil, err
	}

	lenBuf := bytes.NewBuffer(lenBytes)
	if err := binary.Read(lenBuf, binary.BigEndian, &reqLen); err != nil {
		return nil, err
	}

	reqBytes := make([]byte, reqLen)
	_, err := io.ReadFull(conn, reqBytes)

	if err != nil {
		return nil, err
	}

	return reqBytes, nil
}

/*******************************************************
* Send response to client
* conn: conn handler
*******************************************************/
func sendResponse(conn *net.UnixConn, data []byte) {
	buf := new(bytes.Buffer)
	msglen := uint32(len(data))

	binary.Write(buf, binary.BigEndian, &msglen)
	data = append(buf.Bytes(), data...)

	conn.Write(data)
}

Client端程式碼如下:

package main

import (
	"time"
	"io"
	"encoding/binary"
	"bytes"
	"fmt"
	"net"
)

const (
	UNIX_SOCK_PIPE_PATH = "/var/run/unixsock_test.sock" // socket file path
)

var (
	exitSemaphore chan bool
)

func main() {
	// Get unix socket address based on file path
	uaddr, err := net.ResolveUnixAddr("unix", UNIX_SOCK_PIPE_PATH)
	if err != nil {
		fmt.Println(err)
		return
	}

	// Connect server with unix socket
	uconn, err := net.DialUnix("unix", nil, uaddr)
	if err != nil {
		fmt.Println(err)
		return
	}

	// Close unix socket when exit this function
	defer uconn.Close()
	
	// Wait to receive response
	go onMessageReceived(uconn)

	// Send a request to server
	// you can define your own rules
	msg := "tell me current time\n"
	sendRequest(uconn, []byte(msg))

	// Wait server response
	// change this duration bigger than server sleep time to get correct response
	exitSemaphore = make(chan bool)
	select {
	case <- time.After(time.Duration(2) * time.Second):
		fmt.Println("Wait response timeout")
	case <-exitSemaphore:
		fmt.Println("Get response correctly")
	}

	close(exitSemaphore)
}

/*******************************************************
* Send request to server, you can define your own proxy
* conn: conn handler
*******************************************************/
func sendRequest(conn *net.UnixConn, data []byte) {
	buf := new(bytes.Buffer)
	msglen := uint32(len(data))

	binary.Write(buf, binary.BigEndian, &msglen)
	data = append(buf.Bytes(), data...)

	conn.Write(data)
}

/*******************************************************
* Handle connection and response
* conn: conn handler
*******************************************************/
func onMessageReceived(conn *net.UnixConn) {
	//for { // io Read will wait here, we don't need for loop to check
		// Read information from response
		data, err := parseResponse(conn)
		if err != nil {
			fmt.Println(err)
		} else {
			fmt.Printf("%v\tReceived from server: %s\n", time.Now(), string(data))
		}

		// Exit when receive data from server
		exitSemaphore <- true
	//}
}

/*******************************************************
* Parse request of unix socket
* conn: conn handler
*******************************************************/
func parseResponse(conn *net.UnixConn) ([]byte, error) {
	var reqLen uint32
	lenBytes := make([]byte, 4)
	if _, err := io.ReadFull(conn, lenBytes); err != nil {
		return nil, err
	}

	lenBuf := bytes.NewBuffer(lenBytes)
	if err := binary.Read(lenBuf, binary.BigEndian, &reqLen); err != nil {
		return nil, err
	}

	reqBytes := make([]byte, reqLen)
	_, err := io.ReadFull(conn, reqBytes)

	if err != nil {
		return nil, err
	}

	return reqBytes, nil
}