1. 程式人生 > >動手實現一個 LRU cache

動手實現一個 LRU cache

前言

LRU 是 LeastRecentlyUsed 的簡寫,字面意思則是 最近最少使用。通常用於快取的淘汰策略實現,由於快取的記憶體非常寶貴,所以需要根據某種規則來剔除資料保證記憶體不被撐滿。如常用的 Redis 就有以下幾種策略:

策略 描述
volatile-lru 從已設定過期時間的資料集中挑選最近最少使用的資料淘汰
volatile-ttl 從已設定過期時間的資料集中挑選將要過期的資料淘汰
volatile-random 從已設定過期時間的資料集中任意選擇資料淘汰
allkeys-lru 從所有資料集中挑選最近最少使用的資料淘汰
allkeys-random 從所有資料集中任意選擇資料進行淘汰
no-envicition 禁止驅逐資料

實現一

之前也有接觸過一道面試題,大概需求是:

  • 實現一個 LRU 快取,當快取資料達到 N 之後需要淘汰掉最近最少使用的資料。
  • N 小時之內沒有被訪問的資料也需要淘汰掉。

以下是我的實現:

  1. public class LRUAbstractMap extends java.util.AbstractMap {
  2. private final static Logger LOGGER = LoggerFactory.getLogger(LRUAbstractMap.class);
  3. /**
  4.     * 檢查是否超期執行緒
  5.     */
  6. private ExecutorService checkTimePool ;
  7. /**
  8.     * map 最大size
  9.     */
  10. private final static int MAX_SIZE = 1024 ;
  11. private final static ArrayBlockingQueue<Node> QUEUE = new ArrayBlockingQueue<>(MAX_SIZE) ;
  12. /**
  13.     * 預設大小
  14.     */
  15. private final static int DEFAULT_ARRAY_SIZE =1024 ;
  16. /**
  17.     * 陣列長度
  18.     */
  19. private int arraySize ;
  20. /**
  21.     * 陣列
  22.     */
  23. private Object[] arrays ;
  24. /**
  25.     * 判斷是否停止 flag
  26.     */
  27. private volatile boolean flag = true ;
  28. /**
  29.     * 超時時間
  30.     */
  31. private final static Long EXPIRE_TIME = 60 * 60 * 1000L ;
  32. /**
  33.     * 整個 Map 的大小
  34.     */
  35. private volatile AtomicInteger size  ;
  36. public LRUAbstractMap() {
  37.        arraySize = DEFAULT_ARRAY_SIZE;
  38.        arrays = new Object[arraySize] ;
  39. //開啟一個執行緒檢查最先放入佇列的值是否超期
  40.        executeCheckTime();
  41. }
  42. /**
  43.     * 開啟一個執行緒檢查最先放入佇列的值是否超期 設定為守護執行緒
  44.     */
  45. private void executeCheckTime() {
  46. ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
  47. .setNameFormat("check-thread-%d")
  48. .setDaemon(true)
  49. .build();
  50.        checkTimePool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS,
  51. new ArrayBlockingQueue<>(1),namedThreadFactory,new ThreadPoolExecutor.AbortPolicy());
  52.        checkTimePool.execute(new CheckTimeThread()) ;
  53. }
  54. @Override
  55. public Set<Entry> entrySet() {
  56. return super.keySet();
  57. }
  58. @Override
  59. public Object put(Object key, Object value) {
  60. int hash = hash(key);
  61. int index = hash % arraySize ;
  62. Node currentNode = (Node) arrays[index] ;
  63. if (currentNode == null){
  64.            arrays[index] = new Node(null,null, key, value);
  65. //寫入佇列
  66.            QUEUE.offer((Node) arrays[index]) ;
  67.            sizeUp();
  68. }else {
  69. Node cNode = currentNode ;
  70. Node nNode = cNode ;
  71. //存在就覆蓋
  72. if (nNode.key == key){
  73.                cNode.val = value ;
  74. }
  75. while (nNode.next != null){
  76. //key 存在 就覆蓋 簡單判斷
  77. if (nNode.key == key){
  78.                    nNode.val = value ;
  79. break ;
  80. }else {
  81. //不存在就新增連結串列
  82.                    sizeUp();
  83. Node node = new Node(nNode,null,key,value) ;
  84. //寫入佇列
  85.                    QUEUE.offer(currentNode) ;
  86.                    cNode.next = node ;
  87. }
  88.                nNode = nNode.next ;
  89. }
  90. }
  91. return null ;
  92. }
  93. @Override
  94. public Object get(Object key) {
  95. int hash = hash(key) ;
  96. int index = hash % arraySize ;
  97. Node currentNode = (Node) arrays[index] ;
  98. if (currentNode == null){
  99. return null ;
  100. }
  101. if (currentNode.next == null){
  102. //更新時間
  103.            currentNode.setUpdateTime(System.currentTimeMillis());
  104. //沒有衝突
  105. return currentNode ;
  106. }
  107. Node nNode = currentNode ;
  108. while (nNode.next != null){
  109. if (nNode.key == key){
  110. //更新時間
  111.                currentNode.setUpdateTime(System.currentTimeMillis());
  112. return nNode ;
  113. }
  114.            nNode = nNode.next ;
  115. }
  116. return super.get(key);
  117. }
  118. @Override
  119. public Object remove(Object key) {
  120. int hash = hash(key) ;
  121. int index = hash % arraySize ;
  122. Node currentNode = (Node) arrays[index] ;
  123. if (currentNode == null){
  124. return null ;
  125. }
  126. if (currentNode.key == key){
  127.            sizeDown();
  128.            arrays[index] = null ;
  129. //移除佇列
  130.            QUEUE.poll();
  131. return currentNode ;
  132. }
  133. Node nNode = currentNode ;
  134. while (nNode.next != null){
  135. if (nNode.key == key){
  136.                sizeDown();
  137. //在連結串列中找到了 把上一個節點的 next 指向當前節點的下一個節點
  138.                nNode.pre.next = nNode.next ;
  139.                nNode = null ;
  140. //移除佇列
  141.                QUEUE.poll();
  142. return nNode;
  143. }
  144.            nNode = nNode.next ;
  145. }
  146. return super.remove(key);
  147. }
  148. /**
  149.     * 增加size
  150.     */
  151. private void sizeUp(){
  152. //在put值時候認為裡邊已經有資料了
  153.        flag = true ;
  154. if (size == null){
  155.            size = new AtomicInteger() ;
  156. }
  157. int size = this.size.incrementAndGet();
  158. if (size >= MAX_SIZE) {
  159. //找到佇列頭的資料
  160. Node node = QUEUE.poll() ;
  161. if (node == null){
  162. throw new RuntimeException("data error") ;
  163. }
  164. //移除該 key
  165. Object key = node.key ;
  166.            remove(key) ;
  167.            lruCallback() ;
  168. }
  169. }
  170. /**
  171.     * 數量減小
  172.     */
  173. private void sizeDown(){
  174. if (QUEUE.size() == 0){
  175.            flag = false ;
  176. }
  177. this.size.decrementAndGet() ;
  178. }
  179. @Override
  180. public int size() {
  181. return size.get() ;
  182. }
  183. /**
  184.     * 連結串列
  185.     */
  186. private class Node{
  187. private Node next ;
  188. private Node pre ;
  189. private Object key ;
  190. private Object val ;
  191. private Long updateTime ;
  192. public Node(Node pre,Node next, Object key, Object val) {
  193. this.pre = pre ;
  194. this.next = next;
  195. this.key = key;
  196. this.val = val;
  197. this.updateTime = System.currentTimeMillis() ;
  198. }
  199. public void setUpdateTime(Long updateTime) {
  200. this.updateTime = updateTime;
  201. }
  202. public Long getUpdateTime() {
  203. return updateTime;
  204. }
  205. @Override
  206. public String toString() {
  207. return "Node{" +
  208. "key=" + key +
  209. ", val=" + val +
  210. '}';
  211. }
  212. }
  213. /**
  214.     * copy HashMap 的 hash 實現
  215.     * @param key
  216.     * @return
  217.     */
  218. public int hash(Object key) {
  219. int h;
  220. return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
  221. }
  222. private void lruCallback(){
  223.        LOGGER.debug("lruCallback");
  224. }
  225. private class CheckTimeThread implements Runnable{
  226. @Override
  227. public void run() {
  228. while (flag){
  229. try {
  230. Node node = QUEUE.poll();
  231. if (node == null){
  232. continue ;
  233. }
  234. Long updateTime = node.getUpdateTime() ;
  235. if ((updateTime - System.currentTimeMillis()) >= EXPIRE_TIME){
  236.                        remove(node.key) ;
  237. }
  238. } catch (Exception e) {
  239.                    LOGGER.error("InterruptedException");
  240. }
  241. }
  242. }
  243. }
  244. }

感興趣的朋友可以直接從:

https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUAbstractMap.java

下載程式碼本地執行。

程式碼看著比較多,其實實現的思路還是比較簡單:

  • 採用了與 HashMap 一樣的儲存資料方式,只是自己手動實現了一個簡易版。
  • 內部採用了一個佇列來儲存每次寫入的資料。
  • 寫入的時候判斷快取是否大於了閾值 N,如果滿足則根據佇列的 FIFO 特性將佇列頭的資料刪除。因為佇列頭的資料肯定是最先放進去的。
  • 再開啟了一個守護執行緒用於判斷最先放進去的資料是否超期(因為就算超期也是最先放進去的資料最有可能滿足超期條件。)
  • 設定為守護執行緒可以更好的表明其目的(最壞的情況下,如果是一個使用者執行緒最終有可能導致程式不能正常退出,因為該執行緒一直在執行,守護執行緒則不會有這個情況。)

以上程式碼大體功能滿足了,但是有一個致命問題。

就是最近最少使用沒有滿足,刪除的資料都是最先放入的資料。

不過其中的 putget 流程算是一個簡易的 HashMap 實現,可以對 HashMap 加深一些理解。

實現二

因此如何來實現一個完整的 LRU 快取呢,這次不考慮過期時間的問題。

其實從上一個實現也能想到一些思路:

  • 要記錄最近最少使用,那至少需要一個有序的集合來保證寫入的順序。
  • 在使用了資料之後能夠更新它的順序。

基於以上兩點很容易想到一個常用的資料結構:連結串列

  1. 每次寫入資料時將資料放入連結串列頭結點。
  2. 使用資料時候將資料移動到頭結點
  3. 快取數量超過閾值時移除連結串列尾部資料。

因此有了以下實現:

  1. public class LRUMap<K, V> {
  2. private final Map<K, V> cacheMap = new HashMap<>();
  3. /**
  4.     * 最大快取大小
  5.     */
  6. private int cacheSize;
  7. /**
  8.     * 節點大小
  9.     */
  10. private int nodeCount;
  11. /**
  12.     * 頭結點
  13.     */
  14. private Node<K, V> header;
  15. /**
  16.     * 尾結點
  17.     */
  18. private Node<K, V> tailer;
  19. public LRUMap(int cacheSize) {
  20. this.cacheSize = cacheSize;
  21. //頭結點的下一個結點為空
  22.        header = new Node<>();
  23.        header.next = null;
  24. //尾結點的上一個結點為空
  25.        tailer = new Node<>();
  26.        tailer.tail = null;
  27. //雙向連結串列 頭結點的上結點指向尾結點
  28.        header.tail = tailer;
  29. //尾結點的下結點指向頭結點
  30.        tailer.next = header;
  31. }
  32. public void put(K key, V value) {
  33.        cacheMap.put(key, value);
  34. //雙向連結串列中新增結點
  35.        addNode(key, value);
  36. }
  37. public V get(K key){
  38. Node<K, V> node = getNode(key);
  39. //移動到頭結點
  40.        moveToHead(node) ;
  41. return cacheMap.get(key);
  42. }
  43. private void moveToHead(Node<K,V> node){
  44. //如果是最後的一個節點
  45. if (node.tail == null){
  46.            node.next.tail = null ;
  47.            tailer = node.next ;
  48.            nodeCount -- ;
  49. }
  50. //如果是本來就是頭節點 不作處理
  51. if (node.next == null){
  52. return ;
  53. }
  54. //如果處於中間節點
  55. if (node.tail != null && node.next != null){
  56. //它的上一節點指向它的下一節點 也就刪除當前節點
  57.            node.tail.next = node.next ;
  58.            nodeCount -- ;
  59. }
  60. //最後在頭部增加當前節點
  61. //注意這裡需要重新 new 一個物件,不然原本的node 還有著下面的引用,會造成記憶體溢位。
  62.        node = new Node<>(node.getKey(),node.getValue()) ;
  63.        addHead(node) ;
  64. }
  65. /**
  66.     * 連結串列查詢 效率較低
  67.     * @param key
  68.     * @return
  69.     */
  70. private Node<K,V> getNode(K key){
  71. Node<K,V> node = tailer ;
  72. while (node != null){
  73. if (node.getKey().equals(key)){
  74. return node ;
  75. }
  76.            node = node.next ;
  77. }
  78. return null ;
  79. }
  80. /**
  81.     * 寫入頭結點
  82.     * @param key
  83.     * @param value
  84.     */
  85. private void addNode(K key, V value) {
  86. Node<K, V> node = new Node<>(key, value);
  87. //容量滿了刪除最後一個
  88. if (cacheSize == nodeCount) {
  89. //刪除尾結點
  90.            delTail();
  91. }
  92. //寫入頭結點
  93.        addHead(node);
  94. }
  95. /**
  96.     * 新增頭結點
  97.     *
  98.     * @param node
  99.     */
  100. private void addHead(Node<K, V> node) {
  101. //寫入頭結點
  102.        header.next = node;
  103.        node.tail = header;
  104.        header = node;
  105.        nodeCount++;
  106. //如果寫入的資料大於2個 就將初始化的頭尾結點刪除
  107. if (nodeCount == 2) {
  108.            tailer.next.next.tail = null;
  109.            tailer = tailer.next.next;
  110. }
  111. }
  112. private void delTail() {
  113. //把尾結點從快取中刪除
  114.        cacheMap.remove(tailer.getKey());
  115. //刪除尾結點
  116.        tailer.next.tail = null;
  117.        tailer = tailer.next;
  118.        nodeCount--;
  119. }
  120. private class Node<K, V> {
  121. private K key;
  122. private V value;
  123. Node<K, V> tail;
  124. Node<K, V> next;
  125. public Node(K key, V value) {
  126. this.key = key;
  127. this.value = value;
  128. }
  129. public Node() {
  130. }
  131. public K getKey() {
  132. return key;
  133. }
  134. public void setKey(K key) {
  135. this.key = key;
  136. }
  137. public V getValue() {
  138. return value;
  139. }
  140. public void setValue(V value) {
  141. this.value = value;
  142. }
  143. }
  144. @Override
  145. public String toString() {
  146. StringBuilder sb = new StringBuilder() ;
  147. Node<K,V> node = tailer ;
  148. while (node != null){
  149.            sb.append(node.getKey()).append(":")
  150. .append(node.getValue())
  151. .append("-->") ;
  152.            node = node.next ;
  153. }
  154. return sb.toString();
  155. }
  156. }

原始碼: https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUMap.java

實際效果,寫入時:

  1. @Test
  2. public void put() throws Exception {
  3. LRUMap<String,Integer> lruMap = new LRUMap(3) ;
  4.        lruMap.put("1",1) ;
  5.        lruMap.put("2",2) ;
  6.        lruMap.put("3",3) ;
  7. System.out.println(lruMap.toString());
  8.        lruMap.put("4",4) ;
  9. System.out.println(lruMap.toString());
  10.        lruMap.put("5",5) ;
  11. System.out.println(