Snowflake算法

snowflake是twitter为了搬移数据库从Mysql到cassandra生成可排序主键而创造的极其简单高效的分布式主键生成算法。

2014-01-30

21 次浏览

概述

snowflake是twitter为了搬移数据库从Mysql到cassandra生成可排序主键而创造的极其简单高效的分布式主键生成算法。

研究价值

研究snowflake的初衷是我们想创造一个服务域内各自独立运行也能生成唯一ID的生成器,好处就是我们可以不再需要将有业务意义的ID通过数据库来提现,同时又能很好的兼容数据库的排序功能。

我们知道如果使用数据库的自增ID,有一个严重的问题,就是会暴露业务,例如我们可以很轻易的就知道当日的交易数量或者交易频率,通过/transaction/12,/transaction/103等方式。那么UUID呢,UUID本身是一个很好的解决方案,他会生成几十年内都宇宙唯一的ID,除非来自平行宇宙的另一个UUID才有会产生碰撞,或者几十年后会有极其微弱的概率会产生碰撞,但是UUID是字符串形式的,这对于我们采用了mysql这样的数据库来说,用UUID就是一种不友好无法连续排序的主键,由于mysql的存储数据是根据主键来索引的,所以在有一定数据量的时候(通常是在20万左右),mysql就会因为移动太多的数据而导致插入太慢,当达到更大的数据量时(大约是200万-2000万左右),就已经让人无法忍受了。具体的性能可以自行谷歌百度必应。

算法描述

D=毫秒级时间41位+机器ID的10位+毫秒内序列12位,这就是算法的核心公式,我们把位数切开来看:

0---0000000000 0000000000 0000000000 0000000000 0 --- 00000 ---00000 ---0000000000 00

在上面的字符串中,第一位为未使用(实际上也可作为long的符号位),接下来的41位为毫秒级时间,然后5位datacenter标识位,5位机器ID(并不算标识符,实际是为线程标识),然后12位该毫秒内的当前毫秒内的计数,加起来刚好64位,为一个Long型。
这样的好处是,整体上按照时间自增排序,并且整个分布式系统内不会产生ID碰撞(由datacenter和机器ID作区分),并且效率较高,经测试,snowflake每秒能够产生26万ID左右,完全满足需要。

可替代方案

分布式环境下的唯一主键生成策略我们可以归纳为几种方式:

  • 使用UUID/GUID
  • 使用时间戳加上随机码的方式,这种方式可以完全的放在client端,每个生成器单独使用,是一种接近UUID的方式。但是这个方式有一定的几率会出现碰撞。
  • The most significant 32 bits: Timestamp, the generation time of the ID.
  • The least significant 32 bits: 32-bits of randomness, generated anew for each ID.
  • 使用改进的更接近UUID的方式,除了时间戳的,再加上机器的标识,以及该机器的自增码。这样每个生成器也可以单独使用。
  • The most significant 40 or so bits: A timestamp; the generation time of the ID.
  • The next 14 or so bits: A per-generator counter, which each generator increments by one for each new ID generated. This ensures that IDs generated at the same moment (same timestamps) do not overlap.
  • The last 10 or so bits: A unique value for each generator. Using this, we don't need to do any synchronization between generators (which is extremely hard), as all generators produce non-overlapping IDs because of this value.
  • 使用Snowflake服务。这个服务可以在github的twitter/snowflake下载。
  • Networked service, i.e. you make a network call to get a unique ID;
  • which produces 64 bit unique IDs that are ordered by generation time;
  • and the service is highly scalable and (potentially) highly available; each instance can generate many thousand IDs per second, and you can run multiple instances on your LAN/WAN;
  • written in Scala, runs on the JVM.

关于测试

测试的时候,需要注意的是,如果你需要一个类似于Set的全局容器进行测试,一定要保证这个全局容器是线程安全的(Thread Safe),这样才能保证同时插入这个容器的时候不会出现问题。

Scale代码

Twitter的原版实现

/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake

import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger

/**
 * An object that generates IDs.
 * This is broken into a separate class in case
 * we ever want to support multiple worker threads
 * per process
 */
class IdWorker(val workerId: Long, val datacenterId: Long, private val reporter: Reporter, var sequence: Long = 0L)
extends Snowflake.Iface {
  private[this] def genCounter(agent: String) = {
    Stats.incr("ids_generated")
    Stats.incr("ids_generated_%s".format(agent))
  }
  private[this] val exceptionCounter = Stats.getCounter("exceptions")
  private[this] val log = Logger.get
  private[this] val rand = new Random

  val twepoch = 1288834974657L

  private[this] val workerIdBits = 5L
  private[this] val datacenterIdBits = 5L
  private[this] val maxWorkerId = -1L ^ (-1L << workerIdBits)
  private[this] val maxDatacenterId = -1L ^ (-1L << datacenterIdBits)
  private[this] val sequenceBits = 12L

  private[this] val workerIdShift = sequenceBits
  private[this] val datacenterIdShift = sequenceBits + workerIdBits
  private[this] val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits
  private[this] val sequenceMask = -1L ^ (-1L << sequenceBits)

  private[this] var lastTimestamp = -1L

  // sanity check for workerId
  if (workerId > maxWorkerId || workerId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId))
  }

  if (datacenterId > maxDatacenterId || datacenterId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId))
  }

  log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
    timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)

  def get_id(useragent: String): Long = {
    if (!validUseragent(useragent)) {
      exceptionCounter.incr(1)
      throw new InvalidUserAgentError
    }

    val id = nextId()
    genCounter(useragent)

    reporter.report(new AuditLogEntry(id, useragent, rand.nextLong))
    id
  }

  def get_worker_id(): Long = workerId
  def get_datacenter_id(): Long = datacenterId
  def get_timestamp() = System.currentTimeMillis

  protected[snowflake] def nextId(): Long = synchronized {
    var timestamp = timeGen()

    if (timestamp < lastTimestamp) {
      exceptionCounter.incr(1)
      log.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp);
      throw new InvalidSystemClock("Clock moved backwards.  Refusing to generate id for %d milliseconds".format(
        lastTimestamp - timestamp))
    }

    if (lastTimestamp == timestamp) {
      sequence = (sequence + 1) & sequenceMask
      if (sequence == 0) {
        timestamp = tilNextMillis(lastTimestamp)
      }
    } else {
      sequence = 0
    }

    lastTimestamp = timestamp
    ((timestamp - twepoch) << timestampLeftShift) |
      (datacenterId << datacenterIdShift) |
      (workerId << workerIdShift) |
      sequence
  }

  protected def tilNextMillis(lastTimestamp: Long): Long = {
    var timestamp = timeGen()
    while (timestamp <= lastTimestamp) {
      timestamp = timeGen()
    }
    timestamp
  }

  protected def timeGen(): Long = System.currentTimeMillis()

  val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r

  def validUseragent(useragent: String): Boolean = useragent match {
    case AgentParser(_) => true
    case _ => false
  }
}

Java源码

由于项目用的是Java,这里我把scale改为了Java,去除了datacenter编号。

package com.omoney.core.common.generator.flake;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Id 生成器
 * <p>
 * ID发生器的主要目的是为了每个不同名字的发生器都能生成全网唯一的ID
 * <p>
 * id is composed of:
 * time - 41 bits (millisecond precision w/ a custom epoch gives us 69 years)
 * configured machine id - 10 bits - gives us up to 1024 machines
 * sequence number - 12 bits - rolls over every 4096 per machine (with protection to avoid rollover in the same ms)
 */
public class IdWorker {
    private Logger logger = LoggerFactory.getLogger(IdWorker.class);
    private final static long epoch = 1429873810601L;
    private final long workerIdBits = 10L; // worker id 位数
    private final long sequenceBits = 12L; //  本worker的计数器位数

    private final long maxWorkerId = ~(-1L << workerIdBits);
    private final long workerIdShift = sequenceBits;
    private final long timestampShift = sequenceBits + workerIdBits;
    private final long sequenceStep = ~(-1L << sequenceBits);  //同一毫秒内每次计数增长4096

    private long workerId = 0L;
    private long sequence = 0L;
    private long lastTimestamp = -1L;

    /**
     * @param workerId range [0 - 1023]
     */
    public IdWorker(final long workerId) {
        if (workerId > this.maxWorkerId || workerId < 0) {
            throw new IllegalArgumentException(String.format(
                    "worker Id can't be greater than %d or less than 0",
                    this.maxWorkerId));
        }
        this.workerId = workerId;
    }


    protected synchronized long nextId() {
        // 获取当前时间
        long timestamp = System.currentTimeMillis();

        // 如果当前时间比上次的时间要小则说明时间计数出了问题
        if (timestamp < lastTimestamp) {
            logger.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp);
            throw new RuntimeException(String.format("Clock moved backwards.  Refusing to generate id for %d milliseconds",
                    lastTimestamp - timestamp));
        }

        // 同一个毫秒级的时间戳上,那么就需要进行计数
        if (lastTimestamp == timestamp) {
            sequence = (sequence + 1) & sequenceStep;
            if (sequence == 0) {
                // 等待下一个毫秒的到来
                timestamp = waitlNextMillis(lastTimestamp);
            }
        } else {
            sequence = 0;
        }

        lastTimestamp = timestamp;
        return ((timestamp - epoch) << timestampShift) |
                (workerId << workerIdShift) | sequence;
    }

    protected long waitlNextMillis(long lastTimestamp) {
        long timestamp = System.currentTimeMillis();
        while (timestamp <= lastTimestamp) {
            timestamp = System.currentTimeMillis();
        }
        return timestamp;
    }

    public long getWorkerId() {
        return workerId;
    }
}