1. 程式人生 > >雪花演算法(snowflake)

雪花演算法(snowflake)

雪花演算法簡單描述:
+ 最高位是符號位,始終為0,不可用。
+ 41位的時間序列,精確到毫秒級,41位的長度可以使用69年。時間位還有一個很重要的作用是可以根據時間進行排序。
+ 10位的機器標識,10位的長度最多支援部署1024個節點。
+ 12位的計數序列號,序列號即一系列的自增id,可以支援同一節點同一毫秒生成多個ID序號,12位的計數序列號支援每個節點每毫秒產生4096個ID序號。

看得出來,這個演算法很簡潔也很簡單,但依舊是一個很好的ID生成策略。其中,10位機器識別符號一般是5位IDC+5位machine編號,唯一確定一臺機器。

我不知道Twitter是如何確定10位機器標識和12位序列號的,不過,很明顯,這是可以改變的。

/**
 * Snowflake-style 64-bit id generator with a configurable bit layout.
 *
 * <p>Layout of a generated id, from the most significant bit down:
 * sign bit (always 0) | timestamp (milliseconds since {@code START_STAMP}) |
 * IDC id | machine id | sequence. The IDC, machine and sequence fields share
 * {@code REMAIN_BIT_NUM} (22) bits; the split between them is chosen through
 * {@link Factory}.
 *
 * <p>{@link #nextId()} is {@code synchronized}, so one generator instance may be
 * shared between threads.
 */
public class SnowFlakeGenerator {

    /**
     * Creates generators that all use the same IDC / machine bit widths.
     */
    public static class Factory {
        /** Default number of bits reserved for the machine id. */
        private final static int DEFAULT_MACHINE_BIT_NUM = 5;
        /** Default number of bits reserved for the data-center (IDC) id. */
        private final static int DEFAULT_IDC_BIT_NUM = 5;

        private int machineBitNum;
        private int idcBitNum;

        /** Uses the default layout: 5 IDC bits, 5 machine bits, 12 sequence bits. */
        public Factory() {
            this.idcBitNum = DEFAULT_IDC_BIT_NUM;
            this.machineBitNum = DEFAULT_MACHINE_BIT_NUM;
        }

        /**
         * @param machineBitNum number of bits reserved for the machine id
         * @param idcBitNum     number of bits reserved for the IDC id
         */
        public Factory(int machineBitNum, int idcBitNum) {
            this.idcBitNum = idcBitNum;
            this.machineBitNum = machineBitNum;
        }

        /**
         * Builds a generator for one concrete node.
         *
         * @param idcId     data-center id, must fit into the configured IDC bits
         * @param machineId machine id, must fit into the configured machine bits
         * @return a new generator
         * @throws IllegalArgumentException if the bit layout leaves no room for the
         *                                  sequence, or an id does not fit its field
         */
        public SnowFlakeGenerator create(long idcId, long machineId) {
            return new SnowFlakeGenerator(this.idcBitNum, this.machineBitNum, idcId, machineId);
        }
    }

    /** Custom epoch: the timestamp at which the original author wrote this code. */
    private final static long START_STAMP = 1508143349995L;

    /** Total number of bits shared by the IDC, machine and sequence fields. */
    private final static int REMAIN_BIT_NUM = 22;

    /** Data-center id. */
    private long idcId;

    /** Machine id. */
    private long machineId;

    /** Sequence number within the current millisecond. */
    private long sequence = 0L;

    /** Timestamp used for the most recently generated id. */
    private long lastStamp = -1L;

    /** Left shift of the IDC field; computed once. */
    private int idcBitLeftOffset;

    /** Left shift of the machine field; computed once. */
    private int machineBitLeftOffset;

    /** Left shift of the timestamp field; computed once. */
    private int timestampBitLeftOffset;

    /** Largest sequence value (all sequence bits set); computed once. */
    private int maxSequenceValue;

    private SnowFlakeGenerator(int idcBitNum, int machineBitNum, long idcId, long machineId) {
        int sequenceBitNum = REMAIN_BIT_NUM - idcBitNum - machineBitNum;

        if (idcBitNum <= 0 || machineBitNum <= 0 || sequenceBitNum <= 0) {
            throw new IllegalArgumentException("error bit number");
        }

        // An id wider than its field would spill into the neighbouring field and
        // silently corrupt it, so reject it up front.
        long maxIdcId = ~(-1L << idcBitNum);
        long maxMachineId = ~(-1L << machineBitNum);
        if (idcId < 0 || idcId > maxIdcId) {
            throw new IllegalArgumentException(
                    String.format("idc id can't be greater than %d or less than 0", maxIdcId));
        }
        if (machineId < 0 || machineId > maxMachineId) {
            throw new IllegalArgumentException(
                    String.format("machine id can't be greater than %d or less than 0", maxMachineId));
        }

        this.maxSequenceValue = ~(-1 << sequenceBitNum);

        machineBitLeftOffset = sequenceBitNum;
        // The IDC field sits directly above the machine field, so it is shifted by
        // the machine width (not the IDC width) plus the sequence width.
        idcBitLeftOffset = machineBitNum + sequenceBitNum;
        timestampBitLeftOffset = idcBitNum + machineBitNum + sequenceBitNum;

        this.idcId = idcId;
        this.machineId = machineId;
    }

    /**
     * Produces the next id.
     *
     * @return the next id; ids from one generator are strictly increasing
     * @throws RuntimeException if the system clock is behind the timestamp of the
     *                          previously generated id
     */
    public synchronized long nextId() {
        long currentStamp = getTimeMill();
        if (currentStamp < lastStamp) {
            throw new RuntimeException(String.format(
                    "Clock moved backwards. Refusing to generate id for %d milliseconds",
                    lastStamp - currentStamp));
        }

        if (currentStamp == lastStamp) {
            // Same millisecond: advance the sequence.
            sequence = (sequence + 1) & this.maxSequenceValue;
            if (sequence == 0L) {
                // Sequence exhausted for this millisecond: wait for the next one and
                // use THAT timestamp, otherwise the id would repeat an earlier one.
                currentStamp = tilNextMillis();
            }
        } else {
            // New millisecond: the sequence restarts at 0.
            sequence = 0L;
        }

        lastStamp = currentStamp;

        return (currentStamp - START_STAMP) << timestampBitLeftOffset
                | idcId << idcBitLeftOffset
                | machineId << machineBitLeftOffset
                | sequence;
    }

    /** Current wall-clock time in milliseconds. */
    private long getTimeMill() {
        return System.currentTimeMillis();
    }

    /** Busy-waits until the clock has moved past {@code lastStamp} and returns that time. */
    private long tilNextMillis() {
        long timestamp = getTimeMill();
        while (timestamp <= lastStamp) {
            timestamp = getTimeMill();
        }
        return timestamp;
    }
}

也附上Twitter的Scala版:

/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake

import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger

/**
 * Hands out 64-bit ids assembled from a timestamp, a datacenter id, a worker id
 * and a per-millisecond sequence number.
 *
 * It lives in its own class so that a process could host several workers
 * should that ever be needed.
 *
 * @param workerId     id of this worker; construction fails unless it is within 0..31
 * @param datacenterId id of the datacenter; construction fails unless it is within 0..31
 * @param reporter     receives an `AuditLogEntry` for every id returned by `get_id`
 * @param sequence     initial value of the per-millisecond counter
 */
class IdWorker(val workerId: Long, val datacenterId: Long, private val reporter: Reporter, var sequence: Long = 0L)
extends Snowflake.Iface {
  private[this] val exceptionCounter = Stats.getCounter("exceptions")
  private[this] val log = Logger.get
  private[this] val rand = new Random

  // Epoch subtracted from the clock reading before it is packed into an id.
  val twepoch = 1288834974657L

  // Width of each id field, in bits.
  private[this] val workerIdBits = 5L
  private[this] val datacenterIdBits = 5L
  private[this] val sequenceBits = 12L

  // Largest value each field can hold (all of its bits set).
  private[this] val maxWorkerId = ~(-1L << workerIdBits)
  private[this] val maxDatacenterId = ~(-1L << datacenterIdBits)
  private[this] val sequenceMask = ~(-1L << sequenceBits)

  // Each field is shifted past everything stored below it.
  private[this] val workerIdShift = sequenceBits
  private[this] val datacenterIdShift = workerIdShift + workerIdBits
  private[this] val timestampLeftShift = datacenterIdShift + datacenterIdBits

  // Timestamp used for the most recently generated id.
  private[this] var lastTimestamp = -1L

  /** Counts and throws when `value` lies outside `0..ceiling`; `template` takes the ceiling as `%d`. */
  private[this] def rejectIfOutOfRange(value: Long, ceiling: Long, template: String): Unit =
    if (value < 0 || value > ceiling) {
      exceptionCounter.incr(1)
      throw new IllegalArgumentException(template.format(ceiling))
    }

  // Constructor-time sanity checks on both ids.
  rejectIfOutOfRange(workerId, maxWorkerId, "worker Id can't be greater than %d or less than 0")
  rejectIfOutOfRange(datacenterId, maxDatacenterId, "datacenter Id can't be greater than %d or less than 0")

  log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
    timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)

  /**
   * Generates an id on behalf of `useragent`, bumps the generation counters and
   * reports an audit entry.
   *
   * @throws InvalidUserAgentError if `useragent` fails `validUseragent`
   */
  def get_id(useragent: String): Long = {
    if (validUseragent(useragent)) {
      val id = nextId()

      Stats.incr("ids_generated")
      Stats.incr("ids_generated_%s".format(useragent))

      reporter.report(new AuditLogEntry(id, useragent, rand.nextLong))
      id
    } else {
      exceptionCounter.incr(1)
      throw new InvalidUserAgentError
    }
  }

  def get_worker_id(): Long = workerId
  def get_datacenter_id(): Long = datacenterId
  def get_timestamp() = System.currentTimeMillis

  /**
   * Builds the next id while holding this object's monitor.
   *
   * @throws InvalidSystemClock if the clock reads earlier than the last timestamp used
   */
  protected[snowflake] def nextId(): Long = synchronized {
    val now = timeGen()

    if (now < lastTimestamp) {
      exceptionCounter.incr(1)
      log.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp)
      throw new InvalidSystemClock("Clock moved backwards.  Refusing to generate id for %d milliseconds".format(
        lastTimestamp - now))
    }

    val stamp =
      if (now == lastTimestamp) {
        // Same millisecond as the previous id: advance the counter, and when it
        // wraps to zero wait for the clock to move on.
        sequence = (sequence + 1) & sequenceMask
        if (sequence == 0) tilNextMillis(lastTimestamp) else now
      } else {
        // Fresh millisecond: the counter restarts.
        sequence = 0L
        now
      }

    lastTimestamp = stamp

    val timePart = (stamp - twepoch) << timestampLeftShift
    val datacenterPart = datacenterId << datacenterIdShift
    val workerPart = workerId << workerIdShift
    timePart | datacenterPart | workerPart | sequence
  }

  /** Polls `timeGen` until it returns a value greater than `lastTimestamp`, then returns it. */
  protected def tilNextMillis(lastTimestamp: Long): Long = {
    @scala.annotation.tailrec
    def spin(candidate: Long): Long =
      if (candidate > lastTimestamp) candidate else spin(timeGen())

    spin(timeGen())
  }

  /** Clock source, in milliseconds. */
  protected def timeGen(): Long = System.currentTimeMillis()

  // A user agent is a letter followed by any number of letters, digits or hyphens.
  val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r

  /** True when the whole of `useragent` matches `AgentParser`. */
  def validUseragent(useragent: String): Boolean =
    AgentParser.unapplySeq(useragent).isDefined
}