1. 程式人生 > >[05] 通過P/Invoke加速C#程式

[05] 通過P/Invoke加速C#程式

通過P/Invoke加速C#程式

任何語言都會提供FFI機制(Foreign Function Interface, 叫法不太一樣), 大多數的FFI機制是和C API. C#提供了P/Invoke來和作業系統, 第三方擴充套件進行互動.

FFI通常用來和老的程式碼互動, 例如有大量的遺留程式碼, 重寫成本太高, 可以匯出C介面, 然後新系統和老系統互動; 還有一種用處就是優化, 將某一部分功能挪到C/C++(或者其他Native語言)裡面, 通過特殊的優化, 對系統進行加速.

所有的FFI均存在額外的開銷, 除了C++和C這種語言互動. 託管語言和非託管語言, 託管語言和託管語言互動的成本都不小. C#和C的互動, 主要的成本有兩塊:

  • 引數傳遞的成本

    C#裡面的字串是UTF-16編碼的, 但是在C裡面一般使用ASCII或者相容的編碼, 所以呼叫之前需要先做一次轉換.

    記憶體佈局不一樣的引數, 會有額外的開銷.

  • 呼叫的額外開銷

    P/Invoke 的開銷介於每個呼叫 10 到 30 x86 指令之間。 除了此固定成本外,封送還會產生額外的開銷。 在託管程式碼和非託管程式碼中具有相同的表示形式的可宣告型別之間沒有封送成本。 例如,int 和 Int32 之間沒有翻譯費用。

    可以理解為10-30個時鐘週期, 比虛擬函式呼叫成本要高一些.

以上是P/Invoke優化的基礎知識. 只要呼叫的函式執行的時間較長, 引數的轉換足夠少, 那麼進行P/Invoke優化就是有意義的.

某遊戲伺服器使用了AES-ECB加密演算法進行通訊協議的加密. 演算法一直沒改, 實現修改了好幾次, 因為整個編碼過程中, 會產生多個臨時byte[]物件, 所以一直想要優化掉.

下面這個版本是C# Slice的版本, 希望把加密後的內容放到我準備好的Slice裡面(IByteBuffer). 但是其中有一個MemoryStream還是無法處理, 這個物件內部還是會產生byte[].

public static int AesEncrypt(byte[] src, int offset, int count, byte[] dest, int destOffset, byte[] Key0)
{
    using Rijndael rm = Rijndael.Create();
    rm.Key = Key0;
    rm.Mode = CipherMode.ECB;
    rm.Padding = PaddingMode.PKCS7;

    using ICryptoTransform cTransform = rm.CreateEncryptor();
    using var memoryStream = new MemoryStream(dest, destOffset, count + 32);
    using var writer = new CryptoStream(memoryStream, cTransform, CryptoStreamMode.Write);
    writer.Write(src, offset, count);
    writer.FlushFinalBlock();

    return (int)memoryStream.Position;
}

花了好長時間去研究.NET內部的實現, 沒找到解決辦法.

所以這時候就把眼睛轉向了P/Invoke和C++. 好在可以先通過C#的版本生成一個輸入輸出樣本, 然後C++嘗試著去跑通整個輸入輸出.

下面是C++的版本:

aes_ech.h

#pragma once
#include <openssl/aes.h>
#include <assert.h>
#include <string.h>

#ifdef WIN32
#define __DLLIMPORT __declspec(dllimport)
#define __DLLEXPORT __declspec(dllexport)
#else
#define __DLLIMPORT
#define __DLLEXPORT 
#endif

extern "C"
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);
}


static inline int pkcs7padding(unsigned char* data, int length) {
	int padding = AES_BLOCK_SIZE - length % AES_BLOCK_SIZE;
	int destSize = length + padding;
	for (int index = length; index < destSize; ++index) {
		data[index] = padding;
	}
	return destSize;
}

static inline int Encrypt(unsigned char* key, int keyLength,
			unsigned char* src, int srcLength,
			unsigned char* dest) {
	int paddingLength = pkcs7padding(src, srcLength);

	AES_KEY aes_key;
	AES_set_encrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	unsigned char* encrypted = dest;

	for (int block = 0; block < paddingLength; block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&src[block]),
			reinterpret_cast<unsigned char*>(&encrypted[block]),
			&aes_key, AES_ENCRYPT);
	}

	return paddingLength;
}

static inline int pkcs7unpadding(unsigned char* data, int dataLength) {
	int padding = data[dataLength - 1];
	return dataLength - padding;
}

static inline int Decrypt(unsigned char *key, int keyLength,
			unsigned char* encrypted, int encryptedLength,
			unsigned char* decrypted) {
	AES_KEY aes_key;
	AES_set_decrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	int decrypted_length = encryptedLength;

	for (int block = 0; block < encryptedLength;
		block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&encrypted[block]),
			reinterpret_cast<unsigned char*>(&decrypted[block]),
			&aes_key, AES_DECRYPT);
	}

	return pkcs7unpadding(decrypted, encryptedLength);
}

aes_ecb.cpp

#include "aes_ecb.h"

extern "C" 
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Encrypt(key, key_size, source, source_length, dest);
}

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Decrypt(key, key_size, source, source_length, dest);
}
}

C#的P/Invoke封裝, 以及測試程式碼:

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace AesPInvoke
{
    static class AesWin
    {
        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }
    static class AesLinux 
    {
        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }

    static class Aes 
    {
        public unsafe delegate int AesFunc(byte* key, int key_size, byte* source, int source_length, byte* dest);
        static AesFunc encrypt;
        static AesFunc decrypt;
        public static AesFunc AesEncrpt => encrypt;
        public static AesFunc AesDecrypt => decrypt;
        static unsafe Aes() 
        {
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) 
            {
                encrypt = AesLinux.AesEcbEncrypt;
                decrypt = AesLinux.AesEcbDecrypt;
            }
            else 
            {
                encrypt = AesWin.AesEcbEncrypt;
                decrypt = AesWin.AesEcbDecrypt;
            }
        }
    }

    class Program
    {

        private static bool Compare(ArraySegment<byte> a, ArraySegment<byte> b) 
        {
            if (a.Count != b.Count)
            {
                return false;
            }
            for (int i = 0; i < a.Count; ++i) 
            {
                if (a[i] != b[i]) return false;
            }
            return true;
        }

        static  unsafe void Main(string[] args)
        {
            byte[] origin = new byte[] {
                                    0x06, 0x04, 0x34, 0x35, 0x32, 0x56, 0x0a, 0x10, 0x08, 0xf9, 0xeb, 0x06,
                                    0x10, 0x93, 0x12, 0x18, 0x85, 0x1a, 0x20, 0x89, 0xdf, 0xf6, 0xd3, 0x01
            };
            byte[] dest = new byte[] {0x0f, 0xd9, 0x52, 0x10, 0x11, 0x4b, 0xcc, 0xe5,
                              0x48, 0x9d, 0x47, 0x2a, 0x69, 0xa4, 0x19, 0xcc,
                              0x08, 0x6b, 0x7d, 0xe9, 0x65, 0x26, 0x53, 0x10,
                              0x5c, 0xc9, 0x2f, 0xa8, 0x02, 0x43, 0x32, 0x8f};

            var originSegment = new ArraySegment<byte>(origin);
            var destSegment = new ArraySegment<byte>(dest);

            byte[] key = Encoding.UTF8.GetBytes("12345678876543211234567887654abc");

            byte[] input = new byte[origin.Length + 32];
            Array.Copy(origin, input, origin.Length);

            byte[] output = new byte[origin.Length + 32];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesEncrpt(keyPointer, key.Length, inputPointer, origin.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(destSegment, data));
            }

            input = new byte[dest.Length];
            Array.Copy(dest, input, dest.Length);
            output = new byte[dest.Length];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesDecrypt(keyPointer, key.Length, inputPointer, dest.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(originSegment, data));
            }

            Console.WriteLine("Hello World!");
        }
    }
}

跑通測試之後, 就可以整合到系統裡面去, 把託管實現給替換掉. 一次可以把多餘的AllocArray, 和加速同時完成.

C++版本的AES ECB加密使用了OpenSSL庫, 好處是工業級實現, 而且還有可能會有AES-NI加速, Windows上面只需要通過vcpkg就可以方便的移植過來, Linux上面本身就有這個庫.

大部分C#程式碼都可以跑得非常快, 一般情況下是不需要進行這種極端優化. 但是某遊戲伺服器是一個比較特殊的伺服器, 其伺服器只有一個程序, 一個程序內需要跑IO密集, 計算密集(加解密,物理,戰鬥等), 還要承擔GC的負擔, 所以才採用了這種優化方式.

參考:

  1. P/Invoke
  2. P/Invoke開銷
  3. OpenSSL AES
  4. AES-NI Performance
  5. vcpkg

通過P/Invoke加速C#