最近一段時(shí)間在看一本書《Java并發(fā)編程的藝術(shù)》,在P164講到了關(guān)于ConcurrentLinkedQueue的源碼分析徐许,但是這部分源碼非常復(fù)雜,于是我又順手看了一下IDEA的Java源碼卒蘸,發(fā)現(xiàn)在Java8中雌隅,該部分的源碼已經(jīng)被更新過了,正好讀一讀順帶做個(gè)筆記缸沃。
基本介紹
ConcurrentLinkedQueue是一個(gè)列表實(shí)現(xiàn)恰起,包括一個(gè)head和tail引用,該類的初始化過程中趾牧,頭尾引用都被初始化成一個(gè)空的Node检盼,下面我們可以看到相關(guān)代碼:
public class ConcurrentLinkedQueue<E> extends AbstractQueue<E>
implements Queue<E>, java.io.Serializable {
private static class Node<E> {
volatile E item;
volatile Node<E> next;
}
private transient volatile Node<E> head;
private transient volatile Node<E> tail;
public ConcurrentLinkedQueue() {
head = tail = new Node<E>(null);
}
}
入隊(duì)流程
單線程下的入隊(duì)流程為:
- 將新節(jié)點(diǎn)加入到tail引用的next中
- 將新節(jié)點(diǎn)賦值給tail引用
但是在多線程環(huán)境中,需要保障其他線程入隊(duì)和出隊(duì)不受影響翘单,ConcurrentLinkedQueue由CAS算法實(shí)現(xiàn)了無鎖入隊(duì)吨枉,下面是加入節(jié)點(diǎn)的關(guān)鍵代碼:
public boolean offer(E e) {
checkNotNull(e);
final Node<E> newNode = new Node<E>(e);
// 循環(huán)開始,p和t都指向tail哄芜,q指向tail的next
for (Node<E> t = tail, p = t;;) {
Node<E> q = p.next;
if (q == null) {
// q為null代表目前tail后面沒有其他線程插入的節(jié)點(diǎn)貌亭,即p確實(shí)是最后的節(jié)點(diǎn)
if (p.casNext(null, newNode)) {
// 這里casNext函數(shù)的作用是當(dāng)p的next節(jié)點(diǎn)為null時(shí),用newNode更新p的next節(jié)點(diǎn)认臊,更新成功返回true
// 如果casNext更新成功圃庭,證明newNode已經(jīng)成功插入到隊(duì)尾
if (p != t)
// 這一步判斷表明,t即tail已經(jīng)不是真正的隊(duì)尾引用,這是減少cas操作的一步優(yōu)化
// 這里casTail函數(shù)的作用是當(dāng)tail與t相等時(shí)剧腻,用newNode更新tail拘央,在這里CAS失敗也沒有關(guān)系
casTail(t, newNode);
return true;
}
// 如果casNext更新失敗,則重新將p的next賦值給q
}
else if (p == q)
// 當(dāng)p==q只有一種情況恕酸,即p==p.next堪滨,在這種情況下就表明當(dāng)前節(jié)點(diǎn)已經(jīng)離隊(duì),因?yàn)樵诔鲫?duì)操作之后蕊温,ConcurrentLinkedQueue會(huì)將出隊(duì)節(jié)點(diǎn)的next設(shè)為它本身
// 在遇到當(dāng)前節(jié)點(diǎn)已經(jīng)是出隊(duì)節(jié)點(diǎn)的情況下袱箱,表明當(dāng)前節(jié)點(diǎn)已經(jīng)在head之前,因此根據(jù)如下邏輯進(jìn)行更新當(dāng)前節(jié)點(diǎn):1义矛、如果tail已經(jīng)更新发笔,那么將當(dāng)前節(jié)點(diǎn)設(shè)為tail;2凉翻、否則了讨,將當(dāng)前節(jié)點(diǎn)設(shè)為head,因?yàn)椴荒鼙WCtail指向的節(jié)點(diǎn)是否已經(jīng)離隊(duì)
p = (t != (t = tail)) ? t : head;
else
// 當(dāng)tail更新且p不在tail時(shí)制轰,用tail更新p前计,否則用q更新p
p = (p != t && t != (t = tail)) ? t : q;
}
}
如果覺得上述方法過于復(fù)雜,我們可以用一種更簡(jiǎn)單的方案來進(jìn)行結(jié)果相同的操作:
public boolean offer(E e) {
checkNotNull(e);
final Node<E> newNode = new Node<E>(e);
for (; ; ) {
Node<E> t = tail;
if (t.casNext(null, newNode)) {
// 參照單線程的入隊(duì)流程垃杖,casNext成功表明newNode已經(jīng)成功插入到了隊(duì)列里
// 如果casTail失敗了也沒有關(guān)系男杈,失敗了證明有其他的線程在進(jìn)行casTail,至少有一根線程可以成功
casTail(t, newNode);
return true;
}
}
}
而在JDK源碼中调俘,加入了一步優(yōu)化伶棒,這步優(yōu)化是:在插入一個(gè)新節(jié)點(diǎn)時(shí),不著急將tail指向這個(gè)新節(jié)點(diǎn)彩库,而是在插入第二個(gè)新節(jié)點(diǎn)的時(shí)候肤无,才對(duì)tail進(jìn)行cas操作。
這樣做會(huì)導(dǎo)致兩個(gè)問題:
- tail并不在保持原有的一定指向隊(duì)尾的性質(zhì)骇钦;
- 從tail開始需要進(jìn)過幾步查找next才能尋找到真正的隊(duì)尾宛渐;
但是這樣做有一個(gè)好處:減少了至少一半的cas操作,雖然增加了普通的賦值操作眯搭,但是在多線程情況下cas操作的耗時(shí)要遠(yuǎn)遠(yuǎn)大于一般賦值操作的耗時(shí)皇忿,因此這部分優(yōu)化可以增大該容器類的并發(fā)量。而剩下部分的判斷都是為了在進(jìn)行這一步優(yōu)化的情況下坦仍,保證程序的正確性所做的鳍烁。
出隊(duì)流程
單線程情況下的出隊(duì)流程為:
- 如果head==tail,證明隊(duì)列為空繁扎,返回null
- 將隊(duì)首元素的值取出幔荒,作為返回值
- 將head指向head.next
如果按照這種思路糊闽,我們可以直接寫出一個(gè)簡(jiǎn)單寫法的無鎖出隊(duì)方案:
public E poll() {
for (; ; ) {
Node<E> h = head;
if (h.next == null) {
return null;
} else {
if (casHead(h, h.next)) {
if (h.next != null)
return h.next.item;
}
}
}
}
我們?cè)賮砜碕DK源碼中的poll函數(shù)實(shí)現(xiàn),在這個(gè)poll函數(shù)中爹梁,使用了和offer函數(shù)中類似的優(yōu)化方式右犹,在出隊(duì)的時(shí)候并不著急更新head的值,而是緩慢更新姚垃,然后用一部分操作來保證出隊(duì)的正確性:
public E poll() {
restartFromHead:
for (; ; ) {
for (Node<E> h = head, p = h, q; ; ) {
E item = p.item;
if (item != null && p.casItem(item, null)) {
if (p != h)
updateHead(h, ((q = p.next) != null) ? q : p);
return item;
} else if ((q = p.next) == null) {
updateHead(h, p);
return null;
} else if (p == q)
continue restartFromHead;
else
p = q;
}
}
}
性能測(cè)試
這里不光是性能測(cè)試念链,同樣有針對(duì)上述兩種簡(jiǎn)單的無鎖入隊(duì)和出隊(duì)的正確性測(cè)試。我分別開了2根入隊(duì)線程和2根出隊(duì)線程积糯,每根入隊(duì)線程循環(huán)入隊(duì)1000W的數(shù)據(jù)掂墓,下面展示了測(cè)試結(jié)果(因?yàn)槲业碾娔X是4核i5,比較弱雞看成,如果線程開多了那么大量的時(shí)間都在線程切換上君编,測(cè)試結(jié)果就不準(zhǔn)確了):
使用JDK源碼
Test Started: 11:15 25:839
Get thread finished, Total: 10809006
Get thread finished, Total: 9190994
Test Finished: 11:15 31:487
Total Time Cost: 5s 648ms
使用自定義的offer函數(shù)
Test Started: 11:17 36:963
Get thread finished, Total: 9335745
Get thread finished, Total: 10664255
Test Finished: 11:17 41:627
Total Time Cost: 4s 664ms
使用自定義的poll函數(shù)
Test Started: 11:18 17:412
Get thread finished, Total: 9714954
Get thread finished, Total: 10285046
Test Finished: 11:18 21:669
Total Time Cost: 4s 257ms
同時(shí)使用自定義的offer和poll函數(shù)
Test Started: 11:18 51:663
Get thread finished, Total: 10219132
Get thread finished, Total: 9780868
Test Finished: 11:18 56:602
Total Time Cost: 4s 939ms
有點(diǎn)尷尬的是好像優(yōu)化過的源碼是跑的最慢的,應(yīng)該和我只有2根讀寫線程有關(guān)川慌,爭(zhēng)搶的情況比較少吃嘿,爭(zhēng)搶情況越嚴(yán)重,線程越多梦重,源碼的速度應(yīng)該是更快的兑燥。如果誰有更好的機(jī)器可以拿代碼試一下,下面是我的測(cè)試代碼:
public class TestQueue {
private static int TOTAL_COUNT = 10000000;
private static int TOTAL_WRITE = 2;
private static int TOTAL_READ = 2;
private static SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("HH:mm ss:SSS");
public static void main(String[] args) {
AtomicInteger flag = new AtomicInteger(0);
ConcurrentHashMap<Integer, AtomicInteger> total = new ConcurrentHashMap<>(TOTAL_COUNT);
for (int i = 0; i != TOTAL_COUNT; i++) {
total.put(i, new AtomicInteger(0));
}
CustomQueue<Integer> customQueue = new CustomQueue<>();
ExecutorService executor = Executors.newCachedThreadPool();
Date startTime = new Date();
System.out.println("Test Started: " + DATE_FORMAT.format(startTime));
for (int i = 0; i != TOTAL_WRITE; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
for (int i = 0; i != TOTAL_COUNT; i++) {
customQueue.add(i);
}
}
});
}
for (int i = 0; i != TOTAL_READ; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
int sum = 0;
while (flag.get() != TOTAL_WRITE * TOTAL_COUNT) {
Integer num = customQueue.poll();
if (num != null) {
sum++;
flag.incrementAndGet();
total.get(num).incrementAndGet();
}
}
System.out.println("Get thread finished, Total: " + sum);
}
});
}
executor.shutdown();
try {
executor.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
Date endTime = new Date();
long totalTime = endTime.getTime() - startTime.getTime();
for (int i = 0; i != TOTAL_COUNT; i++) {
if (total.get(i).get() != TOTAL_WRITE) {
System.out.println("Test Failed: " + i + " " + total.get(i));
break;
}
}
System.out.println("Test Finished: " + DATE_FORMAT.format(endTime));
System.out.printf("Total Time Cost: %ds %dms", totalTime / 1000, totalTime % 1000);
} catch (InterruptedException e) {
System.out.println("Failure: " + flag.get());
e.printStackTrace();
}
}
}