简述
上一篇我们看了常用的三种阻塞队列特点,今天我们自己实现阻塞队列。
代码
直接上代码,主要维护了一个数组和获取元素/放入元素的指针,还有使用了可重入锁ReentrantLock和条件Condition实现。
队列代码
public class MyBlockingQueue<T> {
private final Object[] queue;
private int takeIndex = 0;
private int putIndex = 0;
private final int size;
private int count = 0;
ReentrantLock lock = new ReentrantLock();
Condition empty = lock.newCondition();
Condition full = lock.newCondition();
public MyBlockingQueue(int size) {
this.size = size;
this.queue = new Object[size];
}
public T take() {
lock.lock();
try {
// 出队
while (count == 0) {
empty.await();
}
T value = dequeue();
full.signal();
return value;
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
lock.unlock();
}
return null;
}
public void put(T value) {
lock.lock();
try {
// 进队
while (count == size) {
full.await();
}
enqueue(value);
empty.signal();
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
lock.unlock();
}
}
private T dequeue() {
Object value = queue[takeIndex];
takeIndex++;
// 如果为最后一位重置index
if (takeIndex == size) {
takeIndex = 0;
}
// 总数减少
count--;
return (T) value;
}
private void enqueue(T value) {
queue[putIndex] = value;
putIndex++;
// 如果为最后一位重置index
if (putIndex == size) {
putIndex = 0;
}
// 总数增加
count++;
}
}
测试代码
public static void main(String[] args) throws InterruptedException {
// 创建一个大小为2的阻塞队列
final MyBlockingQueue<Integer> q = new MyBlockingQueue<>(2);
// 创建2个线程
final int threads = 400;
// 每个线程执行10次
final int times = 100;
// 线程列表,用于等待所有线程完成
List<Thread> threadList = new ArrayList<>(threads * 2);
long startTime = System.currentTimeMillis();
// 创建2个消费者线程,从队列中弹出20次数字并打印弹出的数字
for (int i = 0; i < threads; ++i) {
Thread consumer = new Thread(() -> {
try {
for (int j = 0; j < times; ++j) {
Integer element = q.take();
System.out.println(element);
}
} catch (Exception e) {
e.printStackTrace();
}
});
threadList.add(consumer);
}
for (int i = 0; i < threads; ++i) {
final int offset = i * times;
Thread producer = new Thread(() -> {
try {
for (int j = 0; j < times; ++j) {
q.put(offset + j);
}
} catch (Exception e) {
e.printStackTrace();
}
});
threadList.add(producer);
}
for (Thread thread : threadList) {
thread.start();
}
// 等待所有线程执行完成
for (Thread thread : threadList) {
thread.join();
}
// 打印运行耗时
long endTime = System.currentTimeMillis();
System.out.printf("总耗时:%.2fs%n", (endTime - startTime) / 1e3);
}
评论区