package org . apache . ibatis . block ; import java . util . HashMap ; import java . util . Map ; import java . util . concurrent . ConcurrentHashMap ; import java . util . concurrent . ConcurrentMap ; import java . util . concurrent . Time
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* 阻塞 Map
*/
/**
 * Blocking map: a key/value store where a miss on a key blocks other readers of that same key.
 *
 * <p>Every key has its own {@link ReentrantLock}. {@link #getObject(Object)} acquires the key's
 * lock. On a hit the lock is released immediately. On a miss ({@code null} value) the calling
 * thread keeps holding the lock, so other threads asking for the same key wait until that thread
 * stores a value with {@link #putObject(Object, Object)}, which releases the lock.
 *
 * <p>Keys must be non-null: the per-key lock table is a {@link ConcurrentHashMap}, which rejects
 * {@code null} keys.
 */
public class BlockMap {
    /**
     * Backing storage. It is synchronized because different keys are guarded by different locks,
     * so a bare HashMap could be mutated by several threads at once. A synchronized HashMap (rather
     * than a ConcurrentHashMap) keeps support for {@code null} values.
     */
    private final Map<Object, Object> map = Collections.synchronizedMap(new HashMap<>());

    /**
     * Per-key locks: the key is the map key, the value is the ReentrantLock guarding it.
     */
    private final ConcurrentMap<Object, ReentrantLock> locks = new ConcurrentHashMap<>();

    /**
     * Lock acquisition timeout in milliseconds; a value {@code <= 0} means wait indefinitely.
     */
    private long timeout;

    /**
     * Returns the lock acquisition timeout in milliseconds ({@code <= 0} means no timeout).
     */
    public long getTimeout() {
        return timeout;
    }

    /**
     * Sets the lock acquisition timeout in milliseconds; {@code <= 0} waits indefinitely.
     */
    public void setTimeout(long timeout) {
        this.timeout = timeout;
    }

    /**
     * Looks up {@code key}, blocking while another thread holds the key's lock.
     *
     * <p>If the returned value is {@code null}, the calling thread STILL HOLDS the key's lock and is
     * expected to call {@link #putObject(Object, Object)} for the same key to release it.
     *
     * @param key non-null key
     * @return the stored value, or {@code null} on a miss (lock retained)
     * @throws RuntimeException if the lock cannot be acquired within the timeout, or if the thread
     *     is interrupted while waiting (the interrupt flag is restored)
     */
    public Object getObject(Object key) {
        acquireLock(key);
        Object value = map.get(key);
        if (value != null) {
            // Hit: nobody needs to wait for this key, so release right away.
            releaseLock(key);
        }
        // Miss: the lock stays held by this thread until putObject(key, ...) runs.
        return value;
    }

    /**
     * Stores {@code value} under {@code key} and releases the key's lock if the current thread holds
     * it (normally acquired by a preceding missed {@link #getObject(Object)}).
     *
     * @param key non-null key
     * @param value value to store (may be {@code null})
     */
    public void putObject(Object key, Object value) {
        try {
            map.put(key, value);
        } finally {
            // The lock can only be released by the thread that acquired it; releaseLock is a no-op
            // for any other thread (unlocking from it would throw IllegalMonitorStateException).
            releaseLock(key);
        }
    }

    /**
     * Acquires the lock for {@code key}, honouring {@link #timeout} when it is positive.
     */
    private void acquireLock(Object key) {
        Lock lock = getLockForKey(key);
        if (timeout <= 0) {
            // No timeout configured: wait as long as necessary.
            lock.lock();
            return;
        }
        try {
            if (!lock.tryLock(timeout, TimeUnit.MILLISECONDS)) {
                throw new RuntimeException("获取锁超时!");
            }
        } catch (InterruptedException e) {
            // Do not pretend the lock was acquired: restore the interrupt flag and fail loudly.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while waiting for the lock of key " + key, e);
        }
    }

    /**
     * Returns the lock associated with {@code key}, creating it atomically on first use.
     */
    private Lock getLockForKey(Object key) {
        return locks.computeIfAbsent(key, k -> new ReentrantLock());
    }

    /**
     * Releases the lock for {@code key} if it exists and is held by the current thread.
     */
    private void releaseLock(Object key) {
        ReentrantLock lock = locks.get(key);
        if (lock != null && lock.isHeldByCurrentThread()) {
            lock.unlock();
        }
    }

    /**
     * Runs the manual three-thread demo in {@link #testLock()}.
     */
    public static void main(String[] args) {
        testLock();
    }

    /**
     * Manual demo: thread 1 misses on key "1" and fills it in, while threads 2 and 3 read key "1"
     * concurrently and block on its lock while thread 1 holds it.
     */
    private static void testLock() {
        BlockMap map = new BlockMap();
        new Thread(() -> {
            System.out.println("线程1获取值");
            Object value = map.getObject("1");
            if (value == null) {
                System.out.println("线程1获取值为空,尝试从数据库获取!");
                map.putObject("1", 1);
                System.out.println("线程1从数据库获取值[" + map.getObject("1") + "]并放入缓存!");
            } else {
                System.out.println(value);
            }
        }).start();
        new Thread(() -> {
            System.out.println("线程2获取值");
            System.out.println(map.getObject("1"));
        }).start();
        new Thread(() -> {
            System.out.println("线程3获取值");
            System.out.println(map.getObject("1"));
        }).start();
    }
}