blob: 411dca02b2b600d6179cbf884eeb37098bb0687e [file] [log] [blame]
// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
/// An immutable map from keys to values that supports efficient lookup by
/// key.
///
/// The map is persistent: [put] never modifies the receiver. Whenever a
/// mapping has to be added or changed, a new instance is returned that shares
/// the unchanged parts of its structure with the old one.
///
/// Unlike [Map], this class does not support `null` as a key, and it
/// implements only the functionality needed for a specific use case at the
/// core of the framework.
///
/// The implementation is a variation of the *hash array mapped trie* data
/// structure, using compressed (bitmap indexed) nodes.
///
/// See also:
///
/// * [Bagwell, Phil. Ideal hash trees.](https://infoscience.epfl.ch/record/64398);
/// * [Steindorfer, Michael J., and Jurgen J. Vinju. "Optimizing hash-array mapped tries for fast and lean immutable JVM collections."](https://dl.acm.org/doi/abs/10.1145/2814270.2814312);
/// * [Clojure's `PersistentHashMap`](https://github.com/clojure/clojure/blob/master/src/jvm/clojure/lang/PersistentHashMap.java).
///
class PersistentHashMap<K extends Object, V> {
  /// Creates an empty hash map.
  const PersistentHashMap.empty() : this._(null);

  const PersistentHashMap._(this._root);

  /// Root node of the trie, or `null` while the map has no entries.
  final _TrieNode? _root;

  /// Returns a map holding every mapping of this one plus [key] -> [value].
  ///
  /// If this map already maps [key] to an identical [value], nothing is
  /// allocated and `this` is returned.
  PersistentHashMap<K, V> put(K key, V value) {
    final _TrieNode startNode = _root ?? _CompressedNode.empty;
    final _TrieNode updatedRoot = startNode.put(0, key, key.hashCode, value);
    return identical(updatedRoot, _root)
        ? this
        : PersistentHashMap<K, V>._(updatedRoot);
  }

  /// Returns the value mapped to [key], or `null` if [key] is not in the map.
  @pragma('dart2js:as:trust')
  V? operator [](K key) {
    final _TrieNode? root = _root;
    if (root == null) {
      return null;
    }
    // A plain `as` cast is deliberate: going through _unsafeCast<V?>(...)
    // here leads to worse code generation on the VM.
    return root.get(0, key, key.hashCode) as V?;
  }
}
/// Common base class for the nodes of the hash trie.
///
/// The trie is keyed by consecutive [hashBitsPerLevel]-bit slices of a key's
/// hash code; each level consumes the next slice.
abstract class _TrieNode {
  /// Number of hash bits consumed at each level of the trie.
  static const int hashBitsPerLevel = 5;

  /// Mask that keeps only the lowest [hashBitsPerLevel] bits.
  static const int hashBitsPerLevelMask = (1 << hashBitsPerLevel) - 1;

  /// Returns the [hashBitsPerLevel]-bit slice of [hash] that starts at bit
  /// [bitIndex].
  @pragma('vm:prefer-inline')
  static int trieIndex(int hash, int bitIndex) =>
      hashBitsPerLevelMask & (hash >>> bitIndex);

  /// Inserts the [key] -> [value] mapping into the subtrie rooted at this
  /// node, using the bits of [keyHash] starting at [bitIndex].
  _TrieNode put(int bitIndex, Object key, int keyHash, Object? value);

  /// Looks up the value mapped to [key], using the bits of [keyHash]
  /// starting at [bitIndex].
  Object? get(int bitIndex, Object key, int keyHash);
}
/// A full (uncompressed) node in the trie.
///
/// [descendants] has one slot per trie index, i.e. [numElements]
/// (`1 << _TrieNode.hashBitsPerLevel`) entries. Each slot is either `null`
/// or a reference to a deeper [_TrieNode].
class _FullNode extends _TrieNode {
  _FullNode(this.descendants);

  /// Number of slots in [descendants].
  static const int numElements = 1 << _TrieNode.hashBitsPerLevel;

  // Caveat: this array is actually List<_TrieNode?> but typing it like that
  // would introduce a type check when copying this array. For performance
  // reasons the element type is omitted and (implicit) casts are used when
  // accessing it instead.
  final List<Object?> descendants;

  @override
  _TrieNode put(int bitIndex, Object key, int keyHash, Object? value) {
    final int slot = _TrieNode.trieIndex(keyHash, bitIndex);
    final _TrieNode child =
        _unsafeCast<_TrieNode?>(descendants[slot]) ?? _CompressedNode.empty;
    final _TrieNode updatedChild =
        child.put(bitIndex + _TrieNode.hashBitsPerLevel, key, keyHash, value);
    if (identical(updatedChild, child)) {
      // Nothing changed below this node, so it can be shared as-is.
      return this;
    }
    // Copy only this node's slot array; every other child stays shared with
    // the previous version.
    final List<Object?> updatedDescendants = _copy(descendants)
      ..[slot] = updatedChild;
    return _FullNode(updatedDescendants);
  }

  @override
  Object? get(int bitIndex, Object key, int keyHash) {
    final _TrieNode? child = _unsafeCast<_TrieNode?>(
        descendants[_TrieNode.trieIndex(keyHash, bitIndex)]);
    if (child == null) {
      return null;
    }
    return child.get(bitIndex + _TrieNode.hashBitsPerLevel, key, keyHash);
  }
}
/// Compressed node in the trie.
///
/// Instead of storing the full array of outgoing edges this node uses a
/// compressed representation:
///
/// * [_CompressedNode.occupiedIndices] has a bit set for each trie index
/// that is occupied.
/// * furthermore, each occupied index can either be a `(key, value)` pair
/// representing an actual key/value mapping or a `(null, trieNode)` pair
/// representing a descendant node.
///
/// Pairs are stored densely and ordered by trie index. The pair for an
/// occupied index starts at `2 * _compressedIndex(bit)` in [keyValuePairs],
/// where `_compressedIndex(bit)` counts the occupied indices below it.
///
/// Keys and values are stored together in a single array (instead of two
/// parallel arrays) for performance reasons: this improves memory access
/// locality and reduces memory usage (two arrays of length N take slightly
/// more space than one array of length 2*N).
///
/// When a 17th index would become occupied, [put] inflates the node into a
/// [_FullNode] instead of growing [keyValuePairs] further.
class _CompressedNode extends _TrieNode {
/// Creates a node from an occupancy bitmap and the matching packed pairs.
_CompressedNode(this.occupiedIndices, this.keyValuePairs);
/// Creates a node with no occupied indices, backed by [_emptyArray].
_CompressedNode._empty() : this(0, _emptyArray);
/// Creates a node whose only occupied index, selected by the bits of
/// [keyHash] starting at [bitIndex], holds the descendant [node].
factory _CompressedNode.single(int bitIndex, int keyHash, _TrieNode node) {
final int bit = 1 << _TrieNode.trieIndex(keyHash, bitIndex);
// A single (null, node) pair.
final List<Object?> keyValuePairs = _makeArray(2)
..[1] = node;
return _CompressedNode(bit, keyValuePairs);
}
/// Shared node with no entries. Insertions into an empty (sub)trie start
/// from here.
static final _CompressedNode empty = _CompressedNode._empty();
// Caveat: do not replace with <Object?>[] or const <Object?>[]. Either one
// would introduce polymorphism in the keyValuePairs field and significantly
// degrade performance on the VM, because the VM could no longer
// devirtualize method calls on keyValuePairs.
static final List<Object?> _emptyArray = _makeArray(0);
/// Bitmap in which bit `i` is set when trie index `i` is occupied.
///
/// Only 32 bits are used because [_TrieNode.hashBitsPerLevel] is `5`.
final int occupiedIndices;
/// Packed `(key, value)` and `(null, node)` pairs for the occupied indices,
/// in ascending index order.
final List<Object?> keyValuePairs;
/// Returns a node that additionally maps [key] to [value], or `this` when
/// nothing needs to change.
///
/// This node itself is never mutated: every modified pair array is a copy.
@override
_TrieNode put(int bitIndex, Object key, int keyHash, Object? value) {
final int bit = 1 << _TrieNode.trieIndex(keyHash, bitIndex);
// Number of occupied indices below [bit]. This is the pair position if the
// index is occupied, or the insertion position if it is not.
final int index = _compressedIndex(bit);
if ((occupiedIndices & bit) != 0) {
// Index is occupied.
final Object? keyOrNull = keyValuePairs[2 * index];
final Object? valueOrNode = keyValuePairs[2 * index + 1];
// Is this a (null, trieNode) pair?
if (identical(keyOrNull, null)) {
// Delegate to the descendant, which consumes the next level of hash bits.
final _TrieNode newNode = _unsafeCast<_TrieNode>(valueOrNode).put(
bitIndex + _TrieNode.hashBitsPerLevel, key, keyHash, value);
if (newNode == valueOrNode) {
return this;
}
return _CompressedNode(
occupiedIndices, _copy(keyValuePairs)..[2 * index + 1] = newNode);
}
if (key == keyOrNull) {
// Found key/value pair with a matching key. If values match
// then avoid doing anything otherwise copy and update.
return identical(value, valueOrNode)
? this
: _CompressedNode(
occupiedIndices, _copy(keyValuePairs)..[2 * index + 1] = value);
}
// Two different keys at the same index, resolve collision.
final _TrieNode newNode = _resolveCollision(
bitIndex + _TrieNode.hashBitsPerLevel,
keyOrNull,
valueOrNode,
key,
keyHash,
value);
// The old (key, value) pair becomes a (null, subtrie) pair whose subtrie
// holds both mappings.
return _CompressedNode(
occupiedIndices,
_copy(keyValuePairs)
..[2 * index] = null
..[2 * index + 1] = newNode);
} else {
// Adding new key/value mapping.
final int occupiedCount = _bitCount(occupiedIndices);
if (occupiedCount >= 16) {
// Too many occupied: inflate compressed node into full node and
// update descendant at the corresponding index. The slot being written
// is empty in the inflated node, and that node is freshly allocated, so
// assigning into it does not affect any existing node.
return _inflate(bitIndex)
..descendants[_TrieNode.trieIndex(keyHash, bitIndex)] =
_CompressedNode.empty.put(
bitIndex + _TrieNode.hashBitsPerLevel, key, keyHash, value);
} else {
// Grow keyValuePairs by inserting key/value pair at the given
// index. Pairs before the insertion point are copied in place; pairs
// after it are shifted right by one pair (two slots).
final int prefixLength = 2 * index;
final int totalLength = 2 * occupiedCount;
final List<Object?> newKeyValuePairs = _makeArray(totalLength + 2);
for (int srcIndex = 0; srcIndex < prefixLength; srcIndex++) {
newKeyValuePairs[srcIndex] = keyValuePairs[srcIndex];
}
newKeyValuePairs[prefixLength] = key;
newKeyValuePairs[prefixLength + 1] = value;
for (int srcIndex = prefixLength, dstIndex = prefixLength + 2;
srcIndex < totalLength;
srcIndex++, dstIndex++) {
newKeyValuePairs[dstIndex] = keyValuePairs[srcIndex];
}
return _CompressedNode(occupiedIndices | bit, newKeyValuePairs);
}
}
}
/// Returns the value mapped to [key] in this subtrie, or `null` if it is
/// not found.
@override
Object? get(int bitIndex, Object key, int keyHash) {
final int bit = 1 << _TrieNode.trieIndex(keyHash, bitIndex);
if ((occupiedIndices & bit) == 0) {
return null;
}
final int index = _compressedIndex(bit);
final Object? keyOrNull = keyValuePairs[2 * index];
final Object? valueOrNode = keyValuePairs[2 * index + 1];
// A null key marks a (null, trieNode) pair: continue in the descendant.
if (keyOrNull == null) {
final _TrieNode node = _unsafeCast<_TrieNode>(valueOrNode);
return node.get(bitIndex + _TrieNode.hashBitsPerLevel, key, keyHash);
}
if (key == keyOrNull) {
return valueOrNode;
}
return null;
}
/// Convert this node into an equivalent [_FullNode].
///
/// `(null, node)` pairs are moved over unchanged. Each `(key, value)` pair is
/// pushed one level deeper into a new single-entry compressed node.
_FullNode _inflate(int bitIndex) {
final List<Object?> nodes = _makeArray(_FullNode.numElements);
// Read position in keyValuePairs; it advances by one pair per occupied
// index encountered.
int srcIndex = 0;
for (int dstIndex = 0; dstIndex < _FullNode.numElements; dstIndex++) {
if (((occupiedIndices >>> dstIndex) & 1) != 0) {
final Object? keyOrNull = keyValuePairs[srcIndex];
if (keyOrNull == null) {
nodes[dstIndex] = keyValuePairs[srcIndex + 1];
} else {
nodes[dstIndex] = _CompressedNode.empty.put(
bitIndex + _TrieNode.hashBitsPerLevel,
keyOrNull,
keyValuePairs[srcIndex].hashCode,
keyValuePairs[srcIndex + 1]);
}
srcIndex += 2;
}
}
return _FullNode(nodes);
}
/// Returns how many occupied indices lie below [bit]. This is the position,
/// counted in pairs, of [bit]'s entry in [keyValuePairs].
@pragma('vm:prefer-inline')
int _compressedIndex(int bit) {
return _bitCount(occupiedIndices & (bit - 1));
}
/// Builds the subtrie that replaces an occupied slot when a different [key]
/// lands on the same index as [existingKey].
///
/// Keys with equal full hash codes cannot be separated by further hash bits,
/// so they go into a [_HashCollisionNode]. Otherwise both mappings are
/// inserted into a fresh compressed node at [bitIndex].
static _TrieNode _resolveCollision(int bitIndex, Object existingKey,
Object? existingValue, Object key, int keyHash, Object? value) {
final int existingKeyHash = existingKey.hashCode;
// Check if this is a full hash collision and use _HashCollisionNode
// in this case.
return (existingKeyHash == keyHash)
? _HashCollisionNode.fromCollision(
keyHash, existingKey, existingValue, key, value)
: _CompressedNode.empty
.put(bitIndex, existingKey, existingKeyHash, existingValue)
.put(bitIndex, key, keyHash, value);
}
}
/// Trie node representing a full hash collision.
///
/// Every key stored here has the same full hash code, [hash]. Further hash
/// bits cannot tell such keys apart, so they are kept in a flat list of
/// key/value pairs that is searched linearly with `==`.
class _HashCollisionNode extends _TrieNode {
  _HashCollisionNode(this.hash, this.keyValuePairs);

  /// Creates a node holding exactly the two mappings [keyA] -> [valueA] and
  /// [keyB] -> [valueB], whose keys both hash to [keyHash].
  factory _HashCollisionNode.fromCollision(
      int keyHash, Object keyA, Object? valueA, Object keyB, Object? valueB) {
    final List<Object?> list = _makeArray(4);
    list[0] = keyA;
    list[1] = valueA;
    list[2] = keyB;
    list[3] = valueB;
    return _HashCollisionNode(keyHash, list);
  }

  /// The full hash code shared by every key in this node.
  final int hash;

  /// Keys at even indices, each followed by its value at the next odd index.
  final List<Object?> keyValuePairs;

  @override
  _TrieNode put(int bitIndex, Object key, int keyHash, Object? val) {
    // Is this another full hash collision?
    if (keyHash == hash) {
      final int index = _indexOf(key);
      if (index != -1) {
        // Existing key: share this node if the value is identical, otherwise
        // copy the pairs and replace the value.
        return identical(keyValuePairs[index + 1], val)
            ? this
            : _HashCollisionNode(
                keyHash, _copy(keyValuePairs)..[index + 1] = val);
      }
      // New key with the same hash: append it as one more pair.
      final int length = keyValuePairs.length;
      final List<Object?> newArray = _makeArray(length + 2);
      for (int i = 0; i < length; i++) {
        newArray[i] = keyValuePairs[i];
      }
      newArray[length] = key;
      newArray[length + 1] = val;
      return _HashCollisionNode(keyHash, newArray);
    }

    // Not a full hash collision, need to introduce a _CompressedNode which
    // uses previously unused bits.
    return _CompressedNode.single(bitIndex, hash, this)
        .put(bitIndex, key, keyHash, val);
  }

  /// Returns the value mapped to [key], or `null` if it is not stored here.
  ///
  /// Only keys whose full hash equals [hash] are ever stored in this node
  /// (see [put] and [_HashCollisionNode.fromCollision]). A probe with a
  /// different hash therefore returns `null` immediately. This skips the
  /// linear scan and the user-defined `==` calls it would make.
  @override
  Object? get(int bitIndex, Object key, int keyHash) {
    if (keyHash != hash) {
      return null;
    }
    final int index = _indexOf(key);
    return index < 0 ? null : keyValuePairs[index + 1];
  }

  /// Returns the array index of [key] in [keyValuePairs], or `-1` if it is
  /// absent.
  int _indexOf(Object key) {
    final int length = keyValuePairs.length;
    for (int i = 0; i < length; i += 2) {
      if (key == keyValuePairs[i]) {
        return i;
      }
    }
    return -1;
  }
}
/// Returns the number of set bits in [n], which must fit in 32 bits.
///
/// This is a branch-free parallel bit count that sums ever wider bit fields.
/// It is dart2js safe because only 32-bit integers are involved.
@pragma('vm:prefer-inline')
@pragma('dart2js:tryInline')
int _bitCount(int n) {
  assert((n & 0xFFFFFFFF) == n);
  // Count bits within each 2-bit field.
  final int pairs = n - ((n >> 1) & 0x55555555);
  // Sum neighbouring 2-bit counts into 4-bit fields.
  final int nibbles = (pairs & 0x33333333) + ((pairs >>> 2) & 0x33333333);
  // Sum neighbouring 4-bit counts into 8-bit fields.
  final int bytes = (nibbles + (nibbles >> 4)) & 0x0F0F0F0F;
  // Fold the four byte counts into the lowest byte.
  final int halves = bytes + (bytes >> 8);
  final int total = halves + (halves >> 16);
  return total & 0x0000003F;
}
/// Returns a new fixed-length array with the same elements as [array].
///
/// Caveat: keep the explicit element-wise loop. `List.of` and similar
/// methods are considerably slower.
@pragma('vm:prefer-inline')
@pragma('dart2js:tryInline')
List<Object?> _copy(List<Object?> array) {
  final int length = array.length;
  final List<Object?> result = _makeArray(length);
  for (int i = 0; i < length; i++) {
    result[i] = array[i];
  }
  return result;
}
/// Returns a fixed-length array of [length] elements, all set to `null`.
///
/// Fixed-length arrays are used deliberately because they are smaller and
/// faster to access on the VM. A growable array is represented by two
/// objects: the growable array instance and the fixed-length array it points
/// to. A fixed-length array is therefore faster to allocate, requires less
/// memory, and needs fewer indirections per access.
@pragma('vm:prefer-inline')
@pragma('dart2js:tryInline')
List<Object?> _makeArray(int length) =>
    List<Object?>.filled(length, null, growable: false);
/// Casts [o] to [T].
///
/// This helper becomes a no-op when compiled with dart2js at a high level of
/// optimization, where the `dart2js:as:trust` pragma makes the cast trusted.
@pragma('dart2js:tryInline')
@pragma('dart2js:as:trust')
@pragma('vm:prefer-inline')
T _unsafeCast<T>(Object? o) => o as T;