Computer Science Study Notes Help

Part Ⅱ

9 Symbol Table & Binary Search Tree

9.1 Symbol Table & Elementary Implementation

Symbol table: Key-value pair abstraction.

  • Insert a value with a specified key.

  • Given a key, search for the corresponding value.

Method: Maintain an (unordered) linked list of key-value pairs.

  • Search: Scan through all keys until a match is found (sequential search).

  • Insert: Scan through all keys until a match is found and overwrite its value; if there is no match, add a new node to the front.

import java.util.ArrayList;
import java.util.List;

/**
 * Symbol table backed by an unordered singly linked list (sequential search).
 *
 * <p>get, put and delete scan the list, so each costs O(n) in the worst case.
 * Passing a null value to put() deletes the key, so a null result from get()
 * means "absent".
 */
public class SequentialSearchST<Key, Value> {
    private int n;      // number of key-value pairs
    private Node first; // head of the list (most recently inserted key)

    private class Node {
        private final Key key;
        private Value val;
        private Node next;

        public Node(Key key, Value val, Node next) {
            this.key = key;
            this.val = val;
            this.next = next;
        }
    }

    public SequentialSearchST() { }

    public int size() {
        return n;
    }

    public boolean isEmpty() {
        return size() == 0;
    }

    public boolean contains(Key key) {
        return get(key) != null;
    }

    /** Returns the value paired with key, or null if key is absent. */
    public Value get(Key key) {
        for (Node x = first; x != null; x = x.next) {
            if (key.equals(x.key)) return x.val;
        }
        return null;
    }

    /**
     * Overwrites the value for key if it is present, otherwise prepends a new node.
     * A null val deletes key instead.
     */
    public void put(Key key, Value val) {
        if (val == null) {
            delete(key);
            return;
        }
        for (Node x = first; x != null; x = x.next) {
            if (key.equals(x.key)) {
                x.val = val;
                return;
            }
        }
        first = new Node(key, val, first);
        n++;
    }

    /**
     * Removes key if present (no-op otherwise).
     * Iterative rather than recursive, so deleting near the tail of a very long
     * list cannot throw StackOverflowError.
     */
    public void delete(Key key) {
        Node prev = null;
        for (Node x = first; x != null; prev = x, x = x.next) {
            if (key.equals(x.key)) {
                if (prev == null) first = x.next;
                else prev.next = x.next;
                n--;
                return;
            }
        }
    }

    /** Returns the keys in list order (most recently inserted first). */
    public Iterable<Key> keys() {
        List<Key> list = new ArrayList<>();
        for (Node x = first; x != null; x = x.next) list.add(x.key);
        return list;
    }
}
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

/// Symbol table backed by an unordered singly linked list (sequential search).
///
/// get/put/remove/contains scan the list, so each is O(n) in the worst case.
/// The table owns its nodes: the destructor frees them, copies are deep, and
/// moves transfer ownership.
template <typename Key, typename Value>
class SequentialSearchST {
private:
    struct Node {
        Key key;
        Value val;
        Node* next;
        Node(const Key& key, const Value& val, Node* next) : key(key), val(val), next(next) {}
    };

    Node* first = nullptr;  // head of the list (most recently inserted key)
    int n = 0;              // number of key-value pairs

    /// Returns the node holding key, or nullptr when key is absent.
    Node* find(const Key& key) const {
        for (Node* x = first; x != nullptr; x = x->next) {
            if (x->key == key) return x;
        }
        return nullptr;
    }

    /// Frees every node iteratively, so arbitrarily long lists cannot overflow the stack.
    void clear() noexcept {
        while (first != nullptr) {
            Node* next = first->next;
            delete first;
            first = next;
        }
        n = 0;
    }

public:
    SequentialSearchST() = default;

    /// Deep copy that preserves other's node order.
    SequentialSearchST(const SequentialSearchST& other) {
        Node** tail = &first;
        try {
            for (const Node* x = other.first; x != nullptr; x = x->next) {
                *tail = new Node(x->key, x->val, nullptr);
                tail = &(*tail)->next;
            }
        } catch (...) {
            clear();  // the destructor does not run for a partially constructed object
            throw;
        }
        n = other.n;
    }

    SequentialSearchST(SequentialSearchST&& other) noexcept
        : first(std::exchange(other.first, nullptr)), n(std::exchange(other.n, 0)) {}

    /// Copy-and-swap: covers copy assignment, move assignment and self-assignment.
    SequentialSearchST& operator=(SequentialSearchST other) noexcept {
        std::swap(first, other.first);
        std::swap(n, other.n);
        return *this;
    }

    ~SequentialSearchST() { clear(); }

    [[nodiscard]] int size() const { return n; }

    [[nodiscard]] bool isEmpty() const { return size() == 0; }

    bool contains(const Key& key) const { return find(key) != nullptr; }

    /// Returns the value paired with key.
    /// @throws std::runtime_error if key is not in the table.
    Value get(const Key& key) const {
        if (const Node* x = find(key)) return x->val;
        throw std::runtime_error("Key not found");
    }

    /// Overwrites the value if key is present, otherwise prepends a new node.
    void put(const Key& key, const Value& val) {
        if (Node* x = find(key)) {
            x->val = val;
            return;
        }
        first = new Node(key, val, first);
        ++n;
    }

    /// Removes key if present (no-op otherwise) and frees its node.
    void remove(const Key& key) {
        // Walk the links themselves so head and interior removals share one code path.
        for (Node** link = &first; *link != nullptr; link = &(*link)->next) {
            if ((*link)->key == key) {
                Node* doomed = *link;
                *link = doomed->next;
                delete doomed;
                --n;
                return;
            }
        }
    }

    /// Keys in list order (most recently inserted first).
    std::vector<Key> keys() const {
        std::vector<Key> result;
        result.reserve(static_cast<typename std::vector<Key>::size_type>(n));
        for (const Node* x = first; x != nullptr; x = x->next) result.push_back(x->key);
        return result;
    }
};
class SequentialSearchST:
    """Symbol table backed by an unordered singly linked list (sequential search).

    get/put/delete scan the list, so each costs O(n) in the worst case.
    Putting a value of None deletes the key, so get() returning None means
    "absent".
    """

    class Node:
        """A linked-list node holding one key-value pair."""

        def __init__(self, key, val, next_node=None):
            self.key = key
            self.val = val
            self.next = next_node

    def __init__(self):
        self.n = 0         # number of key-value pairs
        self.first = None  # head of the list (most recently inserted key)

    def size(self):
        return self.n

    def is_empty(self):
        return self.size() == 0

    def contains(self, key):
        return self.get(key) is not None

    def get(self, key):
        """Return the value paired with key, or None if key is absent."""
        x = self.first
        while x is not None:
            if key == x.key:
                return x.val
            x = x.next
        return None

    def put(self, key, val):
        """Overwrite key's value if present, else prepend a node; val=None deletes key."""
        if val is None:
            self.delete(key)
            return
        x = self.first
        while x is not None:
            if key == x.key:
                x.val = val
                return
            x = x.next
        self.first = self.Node(key, val, self.first)
        self.n += 1

    def delete(self, key):
        """Remove key if present (no-op otherwise)."""
        self.first = self._delete(self.first, key)

    def _delete(self, x, key):
        """Unlink key from the list starting at x and return that list's new head.

        Iterative rather than recursive: a recursive walk exceeds Python's
        default recursion limit (~1000 frames) on long lists.
        """
        head = x
        prev = None
        while x is not None:
            if key == x.key:
                self.n -= 1
                if prev is None:
                    return x.next
                prev.next = x.next
                return head
            prev, x = x, x.next
        return head

    def keys(self):
        """Return the keys in list order (most recently inserted first)."""
        keys_list = []
        x = self.first
        while x is not None:
            keys_list.append(x.key)
            x = x.next
        return keys_list

9.1.2 Ordered Array Implementation

Method: Maintain an ordered array of key-value pairs.

  • Search: Binary search

  • Insert: Shift all larger keys one position to the right to make room for the new key.

import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.LinkedList;

/**
 * Ordered symbol table backed by two parallel sorted arrays.
 *
 * <p>keys[0..n) is kept in ascending order and vals[i] is the value for keys[i].
 * Lookups use binary search (O(log n)); put/delete shift larger entries (O(n)).
 * Passing a null value to put() deletes the key.
 */
public class BinarySearchST<Key extends Comparable<Key>, Value> {
    private static final int INIT_CAPACITY = 2;
    private Key[] keys;
    private Value[] vals;
    private int n = 0;

    public BinarySearchST() {
        this(INIT_CAPACITY);
    }

    @SuppressWarnings("unchecked") // generic arrays cannot be created directly in Java
    public BinarySearchST(int capacity) {
        keys = (Key[]) new Comparable[capacity];
        vals = (Value[]) new Object[capacity];
    }

    /** Changes the array length to capacity, keeping the n live entries. */
    private void resize(int capacity) {
        assert capacity >= n;
        keys = Arrays.copyOf(keys, capacity);
        vals = Arrays.copyOf(vals, capacity);
    }

    public int size() {
        return n;
    }

    public boolean isEmpty() {
        return size() == 0;
    }

    public boolean contains(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    /** Returns the value for key, or null if key is absent. */
    public Value get(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to get() is null");
        if (isEmpty()) return null;
        int i = rank(key);
        if (i < n && keys[i].compareTo(key) == 0) return vals[i];
        return null;
    }

    /** Returns the number of keys strictly less than key (binary search). */
    public int rank(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to rank() is null");
        int lo = 0, hi = n - 1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2; // avoids int overflow of lo + hi
            int cmp = key.compareTo(keys[mid]);
            if (cmp < 0) hi = mid - 1;
            else if (cmp > 0) lo = mid + 1;
            else return mid;
        }
        return lo;
    }

    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("first argument to put() is null");
        if (val == null) {
            delete(key);
            return;
        }
        int i = rank(key);
        if (i < n && keys[i].compareTo(key) == 0) {
            vals[i] = val;
            return;
        }
        // Math.max guards a table constructed with capacity 0, where doubling would stay 0.
        if (n == keys.length) resize(Math.max(1, 2 * keys.length));
        for (int j = n; j > i; j--) {
            keys[j] = keys[j - 1];
            vals[j] = vals[j - 1];
        }
        keys[i] = key;
        vals[i] = val;
        n++;
        assert check();
    }

    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to delete() is null");
        if (isEmpty()) return;
        int i = rank(key);
        if (i == n || keys[i].compareTo(key) != 0) {
            return;
        }
        for (int j = i; j < n - 1; j++) {
            keys[j] = keys[j + 1];
            vals[j] = vals[j + 1];
        }
        n--;
        keys[n] = null; // avoid loitering
        vals[n] = null;
        if (n > 0 && n == keys.length / 4) resize(keys.length / 2);
        assert check();
    }

    public void deleteMin() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow error");
        delete(min());
    }

    public void deleteMax() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow error");
        delete(max());
    }

    public Key min() {
        if (isEmpty()) throw new NoSuchElementException("called min() with empty symbol table");
        return keys[0];
    }

    public Key max() {
        if (isEmpty()) throw new NoSuchElementException("called max() with empty symbol table");
        return keys[n - 1];
    }

    public Key select(int k) {
        if (k < 0 || k >= size()) {
            throw new IllegalArgumentException("called select() with invalid argument: " + k);
        }
        return keys[k];
    }

    public Key floor(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        int i = rank(key);
        if (i < n && key.compareTo(keys[i]) == 0) return keys[i];
        if (i == 0) throw new NoSuchElementException("argument to floor() is too small");
        else return keys[i - 1];
    }

    public Key ceiling(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to ceiling() is null");
        int i = rank(key);
        if (i == n) throw new NoSuchElementException("argument to ceiling() is too large");
        else return keys[i];
    }

    public int size(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to size() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to size() is null");
        if (lo.compareTo(hi) > 0) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else return rank(hi) - rank(lo);
    }

    /** Returns all keys in ascending order; an empty table yields an empty iterable. */
    public Iterable<Key> keys() {
        if (isEmpty()) return new LinkedList<>();
        return keys(min(), max());
    }

    public Iterable<Key> keys(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to keys() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to keys() is null");
        LinkedList<Key> queue = new LinkedList<>();
        if (lo.compareTo(hi) > 0) return queue;
        queue.addAll(Arrays.asList(keys).subList(rank(lo), rank(hi)));
        if (contains(hi)) queue.add(keys[rank(hi)]);
        return queue;
    }

    private boolean check() {
        return isSorted() && rankCheck();
    }

    private boolean isSorted() {
        for (int i = 1; i < size(); i++)
            if (keys[i].compareTo(keys[i - 1]) < 0) return false;
        return true;
    }

    private boolean rankCheck() {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (int i = 0; i < size(); i++)
            if (keys[i].compareTo(select(rank(keys[i]))) != 0) return false;
        return true;
    }

    public static void main(String[] args) {
        BinarySearchST<String, Integer> st = new BinarySearchST<>();
        String[] input = {"S", "E", "A", "R", "C", "H", "E", "X", "A", "M", "P", "L", "E"};
        for (int i = 0; i < input.length; i++) {
            String key = input[i];
            st.put(key, i);
        }
        for (String s : st.keys())
            System.out.println(s + " " + st.get(s));
    }
}
#include <cassert>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

/// Ordered symbol table backed by two parallel sorted vectors.
///
/// keys[0..n) is kept in ascending order and vals[i] is the value for keys[i];
/// slots at index >= n are spare capacity. Lookups use binary search (O(log n));
/// put/delete_ shift larger entries (O(n)). Key must support <, >, == and !=,
/// and Key and Value must be default-constructible (spare slots hold defaults).
template <typename Key, typename Value>
class BinarySearchST {
private:
    static constexpr int INIT_CAPACITY = 2;
    std::vector<Key> keys;
    std::vector<Value> vals;
    int n;  // number of key-value pairs

    /// Current number of slots (live pairs plus spare capacity).
    [[nodiscard]] int slots() const { return static_cast<int>(keys.size()); }

    /// Changes the slot count; std::vector::resize keeps the first n live entries.
    void resize(int capacity) {
        assert(capacity >= n);
        keys.resize(static_cast<std::size_t>(capacity));
        vals.resize(static_cast<std::size_t>(capacity));
    }

public:
    BinarySearchST() : BinarySearchST(INIT_CAPACITY) {}

    explicit BinarySearchST(int capacity)
        : keys(static_cast<std::size_t>(capacity)), vals(static_cast<std::size_t>(capacity)), n(0) {}

    [[nodiscard]] int size() const { return n; }

    [[nodiscard]] bool isEmpty() const { return size() == 0; }

    [[nodiscard]] bool contains(const Key& key) const { return get(key).has_value(); }

    /// Value paired with key, or std::nullopt if key is absent.
    [[nodiscard]] std::optional<Value> get(const Key& key) const {
        if (isEmpty()) return std::nullopt;
        const int i = rank(key);
        if (i < n && keys[i] == key) return vals[i];
        return std::nullopt;
    }

    /// Number of keys strictly less than key (also key's insertion index).
    [[nodiscard]] int rank(const Key& key) const {
        int lo = 0;
        int hi = n - 1;
        while (lo <= hi) {
            const int mid = lo + (hi - lo) / 2;  // avoids overflow of lo + hi
            if (key < keys[mid]) hi = mid - 1;
            else if (key > keys[mid]) lo = mid + 1;
            else return mid;
        }
        return lo;
    }

    /// Inserts key with val, or overwrites the value if key is already present.
    /// Every value (including -1) is stored as given; call delete_() to remove a key.
    void put(const Key& key, const Value& val) {
        const int i = rank(key);
        if (i < n && keys[i] == key) {
            vals[i] = val;
            return;
        }
        // Double when full; a table constructed with capacity 0 grows to 1 first.
        if (n == slots()) resize(slots() == 0 ? 1 : 2 * slots());
        for (int j = n; j > i; --j) {
            keys[j] = keys[j - 1];
            vals[j] = vals[j - 1];
        }
        keys[i] = key;
        vals[i] = val;
        ++n;
        assert(check());
    }

    /// Removes key if present (no-op otherwise); shrinks storage when 1/4 full.
    void delete_(const Key& key) {
        if (isEmpty()) return;
        const int i = rank(key);
        if (i == n || keys[i] != key) return;
        for (int j = i; j < n - 1; ++j) {
            keys[j] = keys[j + 1];
            vals[j] = vals[j + 1];
        }
        --n;
        if (n > 0 && n == slots() / 4) resize(slots() / 2);
        assert(check());
    }

    void deleteMin() {
        if (isEmpty()) throw std::runtime_error("Symbol table underflow error");
        delete_(min());
    }

    void deleteMax() {
        if (isEmpty()) throw std::runtime_error("Symbol table underflow error");
        delete_(max());
    }

    [[nodiscard]] Key min() const {
        if (isEmpty()) throw std::runtime_error("called min() with empty symbol table");
        return keys[0];
    }

    [[nodiscard]] Key max() const {
        if (isEmpty()) throw std::runtime_error("called max() with empty symbol table");
        return keys[n - 1];
    }

    [[nodiscard]] Key select(int k) const {
        if (k < 0 || k >= size()) {
            throw std::invalid_argument("called select() with invalid argument: " + std::to_string(k));
        }
        return keys[k];
    }

    /// Largest key <= key. @throws std::runtime_error if every key is greater.
    [[nodiscard]] Key floor(const Key& key) const {
        const int i = rank(key);
        if (i < n && key == keys[i]) return keys[i];
        if (i == 0) throw std::runtime_error("argument to floor() is too small");
        return keys[i - 1];
    }

    /// Smallest key >= key. @throws std::runtime_error if every key is smaller.
    [[nodiscard]] Key ceiling(const Key& key) const {
        const int i = rank(key);
        if (i == n) throw std::runtime_error("argument to ceiling() is too large");
        return keys[i];
    }

    /// Number of keys in [lo, hi].
    [[nodiscard]] int size(const Key& lo, const Key& hi) const {
        if (lo > hi) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        return rank(hi) - rank(lo);
    }

    /// All keys in ascending order; an empty table yields an empty vector.
    [[nodiscard]] std::vector<Key> getkeys() const {
        if (isEmpty()) return {};
        return getkeys(min(), max());
    }

    /// Keys in [lo, hi] in ascending order.
    [[nodiscard]] std::vector<Key> getkeys(const Key& lo, const Key& hi) const {
        std::vector<Key> queue;
        if (lo > hi) return queue;
        const int from = rank(lo);
        const int to = rank(hi);  // computed once instead of on every loop iteration
        for (int i = from; i < to; ++i) queue.push_back(keys[i]);
        if (contains(hi)) queue.push_back(keys[to]);
        return queue;
    }

private:
    [[nodiscard]] bool check() const { return isSorted() && rankCheck(); }

    [[nodiscard]] bool isSorted() const {
        for (int i = 1; i < size(); i++)
            if (keys[i] < keys[i - 1]) return false;
        return true;
    }

    [[nodiscard]] bool rankCheck() const {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (int i = 0; i < size(); i++)
            if (keys[i] != select(rank(keys[i]))) return false;
        return true;
    }
};
class BinarySearchST:
    """Ordered symbol table backed by two parallel sorted lists.

    keys[0:n] is kept in ascending order and vals[i] is the value for keys[i];
    slots at index >= n are spare capacity. Lookups use binary search
    (O(log n)); put/delete shift larger entries (O(n)). As in the Java version,
    putting a value of None deletes the key.
    """

    INIT_CAPACITY = 2
    # Kept only so existing references to the attribute keep working; put() no
    # longer treats -1 specially, so -1 is stored like any other value.
    MISSING_VALUE = -1

    def __init__(self, capacity=INIT_CAPACITY):
        self.keys = [None] * capacity
        self.vals = [None] * capacity
        self.n = 0

    def size(self):
        return self.n

    def isEmpty(self):
        return self.size() == 0

    def contains(self, key):
        return self.get(key) is not None

    def get(self, key):
        """Return the value for key, or None if key is absent."""
        if self.isEmpty():
            return None
        i = self.rank(key)
        if i < self.n and self.keys[i] == key:
            return self.vals[i]
        return None

    def rank(self, key):
        """Return the number of keys strictly less than key (binary search)."""
        lo = 0
        hi = self.n - 1
        while lo <= hi:
            mid = lo + (hi - lo) // 2
            if key < self.keys[mid]:
                hi = mid - 1
            elif key > self.keys[mid]:
                lo = mid + 1
            else:
                return mid
        return lo

    def put(self, key, val):
        """Insert key with val, overwrite its value, or delete key when val is None."""
        if val is None:
            self.delete(key)
            return
        i = self.rank(key)
        if i < self.n and self.keys[i] == key:
            self.vals[i] = val
            return
        if self.n == len(self.keys):
            # max() guards a table created with capacity 0, where doubling would stay 0.
            self.resize(max(1, 2 * len(self.keys)))
        # Shift entries i..n-1 one slot right to open position i.
        self.keys[i + 1:self.n + 1] = self.keys[i:self.n]
        self.vals[i + 1:self.n + 1] = self.vals[i:self.n]
        self.keys[i] = key
        self.vals[i] = val
        self.n += 1
        assert self.check()

    def delete(self, key):
        """Remove key if present (no-op otherwise); shrink storage when 1/4 full."""
        if self.isEmpty():
            return
        i = self.rank(key)
        if i == self.n or self.keys[i] != key:
            return
        # Shift entries i+1..n-1 one slot left, overwriting position i.
        self.keys[i:self.n - 1] = self.keys[i + 1:self.n]
        self.vals[i:self.n - 1] = self.vals[i + 1:self.n]
        self.n -= 1
        self.keys[self.n] = None  # drop references so removed objects can be collected
        self.vals[self.n] = None
        if self.n > 0 and self.n == len(self.keys) // 4:
            self.resize(len(self.keys) // 2)
        assert self.check()

    def deleteMin(self):
        if self.isEmpty():
            raise Exception("Symbol table underflow error")
        self.delete(self.min())

    def deleteMax(self):
        if self.isEmpty():
            raise Exception("Symbol table underflow error")
        self.delete(self.max())

    def min(self):
        """Return the smallest key, or None for an empty table."""
        if self.isEmpty():
            return
        return self.keys[0]

    def max(self):
        """Return the largest key, or None for an empty table."""
        if self.isEmpty():
            return
        return self.keys[self.n - 1]

    def select(self, k):
        if k < 0 or k >= self.size():
            raise ValueError(f"called select() with invalid argument: {k}")
        return self.keys[k]

    def floor(self, key):
        """Return the largest key <= key; raise if every key is greater."""
        i = self.rank(key)
        if i < self.n and key == self.keys[i]:
            return self.keys[i]
        if i == 0:
            raise Exception("argument to floor() is too small")
        else:
            return self.keys[i - 1]

    def ceiling(self, key):
        """Return the smallest key >= key; raise if every key is smaller."""
        i = self.rank(key)
        if i == self.n:
            raise Exception("argument to ceiling() is too large")
        else:
            return self.keys[i]

    def size_range(self, lo, hi):
        """Return the number of keys in [lo, hi]."""
        if lo > hi:
            return 0
        if self.contains(hi):
            return self.rank(hi) - self.rank(lo) + 1
        else:
            return self.rank(hi) - self.rank(lo)

    def getkeys(self):
        """Return all keys in ascending order."""
        if self.isEmpty():
            return []
        return self.keys_range(self.min(), self.max())

    def keys_range(self, lo, hi):
        """Return the keys in [lo, hi] in ascending order."""
        queue = []
        if lo > hi:
            return queue
        for i in range(self.rank(lo), self.rank(hi)):
            queue.append(self.keys[i])
        if self.contains(hi):
            queue.append(self.keys[self.rank(hi)])
        return queue

    def resize(self, capacity):
        """Reallocate both lists to length capacity, keeping the n live entries."""
        assert capacity >= self.n
        tempk = [None] * capacity
        tempv = [None] * capacity
        for i in range(self.n):
            tempk[i] = self.keys[i]
            tempv[i] = self.vals[i]
        self.vals = tempv
        self.keys = tempk

    def check(self):
        return self.isSorted() and self.rankCheck()

    def isSorted(self):
        for i in range(1, self.size()):
            if self.keys[i] < self.keys[i - 1]:
                return False
        return True

    def rankCheck(self):
        for i in range(self.size()):
            if i != self.rank(self.select(i)):
                return False
        for i in range(self.size()):
            if self.keys[i] != self.select(self.rank(self.keys[i])):
                return False
        return True

9.2 Ordered Operation

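Goal: provide an interface that gives clients an ordered symbol table. Keys are Comparable, so the table can support operations such as min, max, floor, and ceiling, and can iterate over keys in sorted order.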

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.TreeMap;

/**
 * Ordered symbol table that delegates storage to a {@link TreeMap}.
 *
 * <p>Null keys are rejected with IllegalArgumentException. Passing a null value
 * to put() removes the key instead of storing null.
 */
public class ST<Key extends Comparable<Key>, Value> implements Iterable<Key> {
    private final TreeMap<Key, Value> map;

    public ST() {
        map = new TreeMap<Key, Value>();
    }

    // Shared null-key guard; each caller supplies its own exact error message.
    private static void requireKey(Object key, String message) {
        if (key == null) {
            throw new IllegalArgumentException(message);
        }
    }

    /** Returns the value for key, or null if key is absent. */
    public Value get(Key key) {
        requireKey(key, "called get() with null key");
        return map.get(key);
    }

    /** Stores val under key; a null val removes key. */
    public void put(Key key, Value val) {
        requireKey(key, "called put() with null key");
        if (val != null) {
            map.put(key, val);
        } else {
            map.remove(key);
        }
    }

    @Deprecated
    public void delete(Key key) {
        requireKey(key, "called delete() with null key");
        map.remove(key);
    }

    public void remove(Key key) {
        requireKey(key, "called remove() with null key");
        map.remove(key);
    }

    public boolean contains(Key key) {
        requireKey(key, "called contains() with null key");
        return map.containsKey(key);
    }

    public int size() {
        return map.size();
    }

    public boolean isEmpty() {
        return size() == 0;
    }

    /** Keys in ascending order (a view backed by the TreeMap). */
    public Iterable<Key> keys() {
        return map.keySet();
    }

    @Deprecated
    public Iterator<Key> iterator() {
        return map.keySet().iterator();
    }

    public Key min() {
        if (isEmpty()) {
            throw new NoSuchElementException("called min() with empty symbol table");
        }
        return map.firstKey();
    }

    public Key max() {
        if (isEmpty()) {
            throw new NoSuchElementException("called max() with empty symbol table");
        }
        return map.lastKey();
    }

    /** Smallest key >= key; throws NoSuchElementException if none exists. */
    public Key ceiling(Key key) {
        requireKey(key, "called ceiling() with null key");
        final Key found = map.ceilingKey(key);
        if (found == null) {
            throw new NoSuchElementException("all keys are less than " + key);
        }
        return found;
    }

    /** Largest key <= key; throws NoSuchElementException if none exists. */
    public Key floor(Key key) {
        requireKey(key, "called floor() with null key");
        final Key found = map.floorKey(key);
        if (found == null) {
            throw new NoSuchElementException("all keys are greater than " + key);
        }
        return found;
    }
}
#include <iostream>
#include <map>
#include <stdexcept>
#include <vector>

/// Ordered symbol table: a thin wrapper over std::map, so keys are kept in
/// ascending order and lookups/updates are O(log n).
template <typename Key, typename Value>
class ST {
private:
    std::map<Key, Value> st;

public:
    ST() = default;

    /// Value for key, or a value-initialized Value{} if key is absent.
    [[nodiscard]] Value get(const Key& key) const {
        const auto it = st.find(key);
        if (it == st.end()) return Value{};
        return it->second;
    }

    /// Inserts key or overwrites its value (single lookup).
    void put(const Key& key, const Value& val) { st.insert_or_assign(key, val); }

    /// Removes key if present (no-op otherwise).
    void remove(const Key& key) { st.erase(key); }

    /// Uses find() rather than std::map::contains, which requires C++20.
    [[nodiscard]] bool contains(const Key& key) const { return st.find(key) != st.end(); }

    [[nodiscard]] int size() const { return static_cast<int>(st.size()); }

    [[nodiscard]] bool isEmpty() const { return st.empty(); }

    /// All keys in ascending order.
    [[nodiscard]] auto keys() const {
        std::vector<Key> keysVec;
        keysVec.reserve(st.size());
        for (const auto& entry : st) keysVec.push_back(entry.first);
        return keysVec;
    }

    [[nodiscard]] auto begin() const { return st.begin(); }

    [[nodiscard]] auto end() const { return st.end(); }

    [[nodiscard]] const Key& min() const {
        if (isEmpty()) throw std::runtime_error("called min() with empty symbol table");
        return st.begin()->first;
    }

    [[nodiscard]] const Key& max() const {
        if (isEmpty()) throw std::runtime_error("called max() with empty symbol table");
        return st.rbegin()->first;
    }

    /// Smallest key >= key.
    /// @throws std::runtime_error if every key is smaller (previously dereferenced end()).
    [[nodiscard]] const Key& ceiling(const Key& key) const {
        const auto it = st.lower_bound(key);
        if (it == st.end()) throw std::runtime_error("argument to ceiling() is too large");
        return it->first;
    }

    /// Largest key <= key.
    /// @throws std::runtime_error if every key is larger (previously decremented begin()).
    [[nodiscard]] const Key& floor(const Key& key) const {
        auto it = st.upper_bound(key);  // first key strictly greater than key
        if (it == st.begin()) throw std::runtime_error("argument to floor() is too small");
        --it;
        return it->first;
    }
};
class ST:
    """Ordered symbol table backed by a dict.

    Keys must be mutually comparable. keys() and iteration yield keys in
    ascending order, matching the TreeMap-based Java ST. min/max/floor/ceiling
    scan every key, so each is O(n). Putting a value of None deletes the key.
    """

    def __init__(self):
        self.st = {}

    def get(self, key):
        """Return the value for key, or None if key is absent."""
        if key is None:
            raise ValueError("called get() with null key")
        return self.st.get(key)

    def put(self, key, val):
        """Store val under key; val=None removes key (no error if it is absent)."""
        if key is None:
            raise ValueError("called put() with null key")
        if val is None:
            # pop() with a default makes deleting an absent key a no-op instead of a KeyError.
            self.st.pop(key, None)
        else:
            self.st[key] = val

    def remove(self, key):
        """Remove key if present (no-op otherwise, like TreeMap.remove)."""
        if key is None:
            raise ValueError("called remove() with null key")
        self.st.pop(key, None)

    def contains(self, key):
        if key is None:
            raise ValueError("called contains() with null key")
        return key in self.st

    def size(self):
        return len(self.st)

    def is_empty(self):
        return self.size() == 0

    def keys(self):
        """Return all keys as a list in ascending order."""
        return sorted(self.st)

    def __iter__(self):
        return iter(self.keys())

    def min(self):
        if self.is_empty():
            raise RuntimeError("called min() with empty symbol table")
        return min(self.st)

    def max(self):
        if self.is_empty():
            raise RuntimeError("called max() with empty symbol table")
        return max(self.st)

    def ceiling(self, key):
        """Return the smallest key >= key."""
        if key is None:
            raise ValueError("called ceiling() with null key")
        keys_greater_equal = [k for k in self.st if k >= key]
        if not keys_greater_equal:
            raise RuntimeError("all keys are less than {}".format(key))
        return min(keys_greater_equal)

    def floor(self, key):
        """Return the largest key <= key."""
        if key is None:
            raise ValueError("called floor() with null key")
        keys_less_equal = [k for k in self.st if k <= key]
        if not keys_less_equal:
            raise RuntimeError("all keys are greater than {}".format(key))
        return max(keys_less_equal)

9.3 Binary Search Trees

Binary Search Tree: A BST is a binary tree in symmetric order.

A binary tree is either:

  • Empty.

  • A node with links to two disjoint binary trees (left and right).

Binary Search Tree

Symmetric order: Each node has a key, and every node's key is:

  • Larger than all keys in the left subtree.

  • Smaller than all keys in the right subtree.

Symmetric Order

BST Search

  1. If less, go left.

  2. If greater, go right.

  3. If equal, search hit.

BST Insertion

  • Search for the key, then handle two cases:

    • Key in tree => reset value

    • Key not in tree => add new node

Property: If $N$ distinct keys are inserted into a BST in random order, the expected number of compares for a search/insert is $\sim 2\ln N \approx 1.39\lg N$.

Proof: One-to-one correspondence with quicksort partitioning.

  • Floor: Largest key $\le$ a given key.

  • Ceiling: Smallest key $\ge$ a given key.

  • Rank: Number of keys $< k$.

Computing the Floor

  • $k$ (the query key) equals the key at the root ⇒ the floor of $k$ is $k$.

  • $k$ is less than the key at the root ⇒ the floor of $k$ is in the left subtree.

  • $k$ is greater than the key at the root ⇒ the floor of $k$ is in the right subtree if that subtree contains any key $\le k$; otherwise, it is the key at the root.

Computing the Floor

Deleting the Minimum

  1. Go left until finding a node with a null left link.

  2. Replace that node by its right link.

  3. Update subtree counts.

Delete the minimum

Hibbard Deletion

  • 0 children: Delete the node by setting its parent's link to null.

    Hibbard Deletion 0 children
  • 1 child: Delete the node by replacing its parent's link with a link to its only child.

    Hibbard Deletion 1 child
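  • 2 children (deleting node $t$):

    • Find the successor $x$ of $t$, i.e. the node with the minimum key in $t$'s right subtree.

    • Delete the minimum in $t$'s right subtree (this unlinks $x$ from it).

    • Put $x$ in $t$'s spot, with $t$'s left subtree and the updated right subtree as its children.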

    Hibbard Deletion 2 children
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.LinkedList;

/**
 * Ordered symbol table implemented as an (unbalanced) binary search tree.
 *
 * <p>Every node caches the size of its subtree, so rank() and select() run in time
 * proportional to the tree height. Passing a null value to put() deletes the key,
 * and delete() uses Hibbard deletion.
 */
public class BST<Key extends Comparable<Key>, Value> {
    private Node root;

    private class Node {
        private final Key key;
        private Value val;
        private Node left, right;
        private int size; // number of nodes in the subtree rooted at this node

        public Node(Key key, Value val, int size) {
            this.key = key;
            this.val = val;
            this.size = size;
        }
    }

    public BST() { }

    public boolean isEmpty() {
        return size() == 0;
    }

    public int size() {
        return size(root);
    }

    private int size(Node x) {
        if (x == null) return 0;
        else return x.size;
    }

    public boolean contains(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    /** Returns the value for key, or null if key is absent. */
    public Value get(Key key) {
        return get(root, key);
    }

    private Value get(Node x, Key key) {
        if (key == null) throw new IllegalArgumentException("calls get() with a null key");
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return get(x.left, key);
        else if (cmp > 0) return get(x.right, key);
        else return x.val;
    }

    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("calls put() with a null key");
        if (val == null) {
            delete(key);
            return;
        }
        root = put(root, key, val);
        assert check();
    }

    private Node put(Node x, Key key, Value val) {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    public void deleteMin() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow");
        root = deleteMin(root);
        assert check();
    }

    // Go left until a node has no left child, then replace it by its right child.
    private Node deleteMin(Node x) {
        if (x.left == null) return x.right;
        x.left = deleteMin(x.left);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    public void deleteMax() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow");
        root = deleteMax(root);
        assert check();
    }

    private Node deleteMax(Node x) {
        if (x.right == null) return x.left;
        x.right = deleteMax(x.right);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("calls delete() with a null key");
        root = delete(root, key);
        assert check();
    }

    // Hibbard deletion: a node with two children is replaced by its successor.
    private Node delete(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = delete(x.left, key);
        else if (cmp > 0) x.right = delete(x.right, key);
        else {
            if (x.right == null) return x.left;
            if (x.left == null) return x.right;
            Node t = x;
            x = min(t.right);
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    public Key min() {
        if (isEmpty()) throw new NoSuchElementException("calls min() with empty symbol table");
        return min(root).key;
    }

    private Node min(Node x) {
        if (x.left == null) return x;
        else return min(x.left);
    }

    public Key max() {
        if (isEmpty()) throw new NoSuchElementException("calls max() with empty symbol table");
        return max(root).key;
    }

    private Node max(Node x) {
        if (x.right == null) return x;
        else return max(x.right);
    }

    /** Largest key <= key. */
    public Key floor(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        if (isEmpty()) throw new NoSuchElementException("calls floor() with empty symbol table");
        Node x = floor(root, key);
        if (x == null) throw new NoSuchElementException("argument to floor() is too small");
        else return x.key;
    }

    private Node floor(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) return floor(x.left, key);
        Node t = floor(x.right, key);
        if (t != null) return t;
        else return x;
    }

    /** Iterative-style alternative to floor() that carries the best candidate down the tree. */
    public Key floor2(Key key) {
        // Same null-key guard as floor(); previously a null key surfaced as a NullPointerException.
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        Key x = floor2(root, key, null);
        if (x == null) throw new NoSuchElementException("argument to floor() is too small");
        else return x;
    }

    private Key floor2(Node x, Key key, Key best) {
        if (x == null) return best;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return floor2(x.left, key, best);
        else if (cmp > 0) return floor2(x.right, key, x.key);
        else return x.key;
    }

    /** Smallest key >= key. */
    public Key ceiling(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to ceiling() is null");
        if (isEmpty()) throw new NoSuchElementException("calls ceiling() with empty symbol table");
        Node x = ceiling(root, key);
        if (x == null) throw new NoSuchElementException("argument to ceiling() is too large");
        else return x.key;
    }

    private Node ceiling(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) {
            Node t = ceiling(x.left, key);
            if (t != null) return t;
            else return x;
        }
        return ceiling(x.right, key);
    }

    /** Key of the given rank (0-based). */
    public Key select(int rank) {
        if (rank < 0 || rank >= size()) {
            throw new IllegalArgumentException("argument to select() is invalid: " + rank);
        }
        return select(root, rank);
    }

    private Key select(Node x, int rank) {
        if (x == null) return null;
        int leftSize = size(x.left);
        if (leftSize > rank) return select(x.left, rank);
        else if (leftSize < rank) return select(x.right, rank - leftSize - 1);
        else return x.key;
    }

    /** Number of keys strictly less than key. */
    public int rank(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to rank() is null");
        return rank(key, root);
    }

    private int rank(Key key, Node x) {
        if (x == null) return 0;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return rank(key, x.left);
        else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
        else return size(x.left);
    }

    public Iterable<Key> keys() {
        if (isEmpty()) return new LinkedList<>();
        return keys(min(), max());
    }

    public Iterable<Key> keys(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to keys() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to keys() is null");
        Queue<Key> queue = new LinkedList<>();
        keys(root, queue, lo, hi);
        return queue;
    }

    // In-order traversal pruned to subtrees that can contain keys in [lo, hi].
    private void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
        if (x == null) return;
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        if (cmplo < 0) keys(x.left, queue, lo, hi);
        if (cmplo <= 0 && cmphi >= 0) queue.add(x.key);
        if (cmphi > 0) keys(x.right, queue, lo, hi);
    }

    public int size(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to size() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to size() is null");
        if (lo.compareTo(hi) > 0) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else return rank(hi) - rank(lo);
    }

    /** Height of the tree; an empty tree has height -1 and a single node has height 0. */
    public int height() {
        return height(root);
    }

    private int height(Node x) {
        if (x == null) return -1;
        return 1 + Math.max(height(x.left), height(x.right));
    }

    public Iterable<Key> levelOrder() {
        Queue<Key> keys = new LinkedList<>();
        Queue<Node> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            Node x = queue.remove();
            if (x == null) continue;
            keys.add(x.key);
            queue.add(x.left);
            queue.add(x.right);
        }
        return keys;
    }

    private boolean check() {
        if (!isBST()) System.out.println("Not in symmetric order");
        if (!isSizeConsistent()) System.out.println("Subtree counts not consistent");
        if (!isRankConsistent()) System.out.println("Ranks not consistent");
        return isBST() && isSizeConsistent() && isRankConsistent();
    }

    private boolean isBST() {
        return isBST(root, null, null);
    }

    private boolean isBST(Node x, Key min, Key max) {
        if (x == null) return true;
        if (min != null && x.key.compareTo(min) <= 0) return false;
        if (max != null && x.key.compareTo(max) >= 0) return false;
        return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
    }

    private boolean isSizeConsistent() {
        return isSizeConsistent(root);
    }

    private boolean isSizeConsistent(Node x) {
        if (x == null) return true;
        if (x.size != size(x.left) + size(x.right) + 1) return false;
        return isSizeConsistent(x.left) && isSizeConsistent(x.right);
    }

    private boolean isRankConsistent() {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (Key key : keys())
            if (key.compareTo(select(rank(key))) != 0) return false;
        return true;
    }
}
#include <algorithm>
#include <iostream>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>

/**
 * Ordered symbol table implemented as an unbalanced binary search tree.
 *
 * Keys are compared with `<` and `>` only. Every node caches the number of
 * nodes in its subtree, so rank()/select()/size(lo, hi) follow a single
 * root-to-leaf path. The tree owns its nodes: removed nodes are freed
 * immediately and the rest in the destructor. Copying is disabled because a
 * member-wise copy would free every node twice.
 */
template <typename Key, typename Value>
class BST {
private:
    struct Node {
        Key key;
        Value val;
        Node *left, *right;
        int size;  // number of nodes in the subtree rooted here

        Node(Key key, const Value& val, const int size)
            : key(std::move(key)), val(val), left(nullptr), right(nullptr), size(size) {}
    };

    Node* root;

    // Size of the subtree rooted at x (0 for an empty subtree).
    static int size(Node* x) {
        if (x == nullptr) return 0;
        return x->size;
    }

    // Node holding `key` in the subtree rooted at x, or nullptr if absent.
    static Node* find(Node* x, const Key& key) {
        while (x != nullptr) {
            if (key < x->key) x = x->left;
            else if (key > x->key) x = x->right;
            else return x;
        }
        return nullptr;
    }

    // Value stored for `key`, or a default-constructed Value when absent.
    Value get(Node* x, const Key& key) const {
        const Node* hit = find(x, key);
        if (hit == nullptr) return Value();
        return hit->val;
    }

    // Inserts or overwrites `key`; returns the (possibly new) subtree root.
    Node* put(Node* x, const Key& key, const Value& val) {
        if (x == nullptr) return new Node(key, val, 1);
        if (key < x->key) x->left = put(x->left, key, val);
        else if (key > x->key) x->right = put(x->right, key, val);
        else x->val = val;
        x->size = 1 + size(x->left) + size(x->right);
        return x;
    }

    // Removes and frees the smallest node of a non-empty subtree.
    Node* deleteMin(Node* x) {
        if (x->left == nullptr) {
            Node* temp = x->right;
            delete x;
            return temp;
        }
        x->left = deleteMin(x->left);
        x->size = size(x->left) + size(x->right) + 1;
        return x;
    }

    // Removes and frees the largest node of a non-empty subtree.
    Node* deleteMax(Node* x) {
        if (x->right == nullptr) {
            Node* temp = x->left;
            delete x;
            return temp;
        }
        x->right = deleteMax(x->right);
        x->size = size(x->left) + size(x->right) + 1;
        return x;
    }

    // Hibbard deletion of `key` from the subtree rooted at x.
    // The node that is unlinked is always freed.
    Node* deleteKey(Node* x, const Key& key) {
        if (x == nullptr) return nullptr;
        if (key < x->key) x->left = deleteKey(x->left, key);
        else if (key > x->key) x->right = deleteKey(x->right, key);
        else {
            // Zero or one child: splice the node out and free it.
            if (x->right == nullptr || x->left == nullptr) {
                Node* child = (x->right == nullptr) ? x->left : x->right;
                delete x;
                return child;
            }
            // Two children: copy the successor's entry into this node, then
            // remove the successor from the right subtree. deleteMin() frees
            // the successor, so no pointer to it may be used afterwards.
            const Node* successor = min(x->right);
            x->key = successor->key;
            x->val = successor->val;
            x->right = deleteMin(x->right);
        }
        x->size = size(x->left) + size(x->right) + 1;
        return x;
    }

    // Leftmost node of a non-empty subtree.
    Node* min(Node* x) const {
        if (x->left == nullptr) return x;
        return min(x->left);
    }

    // Rightmost node of a non-empty subtree.
    Node* max(Node* x) const {
        if (x->right == nullptr) return x;
        return max(x->right);
    }

    // Node with the largest key <= `key`, or nullptr if there is none.
    Node* floor(Node* x, const Key& key) const {
        if (x == nullptr) return nullptr;
        if (key < x->key) return floor(x->left, key);
        if (key > x->key) {
            Node* t = floor(x->right, key);
            if (t != nullptr) return t;
            return x;
        }
        return x;
    }

    // Node with the smallest key >= `key`, or nullptr if there is none.
    Node* ceiling(Node* x, const Key& key) const {
        if (x == nullptr) return nullptr;
        if (key > x->key) return ceiling(x->right, key);
        if (key < x->key) {
            Node* t = ceiling(x->left, key);
            if (t != nullptr) return t;
            return x;
        }
        return x;
    }

    // Key of the given 0-based rank within the subtree (Key() if out of range).
    Key select(Node* x, int rank) const {
        if (x == nullptr) return Key();
        int leftSize = size(x->left);
        if (leftSize > rank) return select(x->left, rank);
        else if (leftSize < rank) return select(x->right, rank - leftSize - 1);
        else return x->key;
    }

    // Number of keys in the subtree that are strictly less than `key`.
    int rank(const Key& key, Node* x) const {
        if (x == nullptr) return 0;
        if (key < x->key) return rank(key, x->left);
        else if (key > x->key) return 1 + size(x->left) + rank(key, x->right);
        else return size(x->left);
    }

    // Appends the keys in [lo, hi] to `queue` in ascending order.
    void keys(Node* x, std::queue<Key>& queue, const Key& lo, const Key& hi) const {
        if (x == nullptr) return;
        if (lo < x->key) keys(x->left, queue, lo, hi);
        if (lo <= x->key && x->key <= hi) queue.push(x->key);
        if (hi > x->key) keys(x->right, queue, lo, hi);
    }

    // Height of the subtree; -1 for an empty subtree, 0 for a single node.
    int height(Node* x) const {
        if (x == nullptr) return -1;
        return 1 + std::max(height(x->left), height(x->right));
    }

public:
    BST() : root(nullptr) {}

    ~BST() { destroy(root); }

    // The tree owns raw node pointers; a shallow copy would double-free.
    BST(const BST&) = delete;
    BST& operator=(const BST&) = delete;

    /// Frees every node of the subtree rooted at `node` (post-order).
    void destroy(Node* node) {
        if (node == nullptr) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    [[nodiscard]] bool isEmpty() const { return size() == 0; }

    [[nodiscard]] int size() const { return size(root); }

    /// True iff `key` is stored, regardless of its value. The old check,
    /// `get(key) != Value()`, missed keys whose value equals Value().
    [[nodiscard]] bool contains(const Key& key) const { return find(root, key) != nullptr; }

    /// Value for `key`, or a default-constructed Value if `key` is absent.
    [[nodiscard]] Value get(const Key& key) const { return get(root, key); }

    void put(const Key& key, const Value& val) { root = put(root, key, val); }

    /// @throws std::runtime_error if the table is empty.
    void deleteMin() {
        if (isEmpty()) throw std::runtime_error("Symbol table underflow");
        root = deleteMin(root);
    }

    /// @throws std::runtime_error if the table is empty.
    void deleteMax() {
        if (isEmpty()) throw std::runtime_error("Symbol table underflow");
        root = deleteMax(root);
    }

    /// Removes `key` if present; a missing key is a no-op.
    void deleteKey(const Key& key) { root = deleteKey(root, key); }

    [[nodiscard]] Key min() const {
        if (isEmpty()) throw std::runtime_error("calls min() with empty symbol table");
        return min(root)->key;
    }

    [[nodiscard]] Key max() const {
        if (isEmpty()) throw std::runtime_error("calls max() with empty symbol table");
        return max(root)->key;
    }

    /// Largest key <= `key`. @throws std::runtime_error if empty or none exists.
    [[nodiscard]] Key floor(const Key& key) const {
        if (isEmpty()) throw std::runtime_error("calls floor() with empty symbol table");
        Node* x = floor(root, key);
        if (x == nullptr) throw std::runtime_error("argument to floor() is too small");
        else return x->key;
    }

    /// Smallest key >= `key`. @throws std::runtime_error if empty or none exists.
    [[nodiscard]] Key ceiling(const Key& key) const {
        if (isEmpty()) throw std::runtime_error("calls ceiling() with empty symbol table");
        Node* x = ceiling(root, key);
        if (x == nullptr) throw std::runtime_error("argument to ceiling() is too large");
        else return x->key;
    }

    /// Key with 0-based rank `rank`. @throws std::runtime_error if out of range.
    [[nodiscard]] Key select(int rank) const {
        if (rank < 0 || rank >= size()) {
            throw std::runtime_error("argument to select() is invalid: " + std::to_string(rank));
        }
        return select(root, rank);
    }

    /// Number of keys strictly less than `key`.
    [[nodiscard]] int rank(const Key& key) const { return rank(key, root); }

    /// All keys in ascending order.
    [[nodiscard]] std::queue<Key> keys() const {
        if (isEmpty()) return std::queue<Key>();
        return keys(min(), max());
    }

    /// Keys in [lo, hi] in ascending order.
    [[nodiscard]] std::queue<Key> keys(const Key& lo, const Key& hi) const {
        std::queue<Key> queue;
        keys(root, queue, lo, hi);
        return queue;
    }

    /// Number of keys in [lo, hi].
    [[nodiscard]] int size(const Key& lo, const Key& hi) const {
        if (lo > hi) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else return rank(hi) - rank(lo);
    }

    [[nodiscard]] int height() const { return height(root); }

    /// Keys in breadth-first order, root first, left before right.
    [[nodiscard]] std::queue<Key> levelOrder() const {
        std::queue<Key> order;
        std::queue<Node*> pending;
        pending.push(root);
        while (!pending.empty()) {
            Node* x = pending.front();
            pending.pop();
            if (x == nullptr) continue;
            order.push(x->key);
            pending.push(x->left);
            pending.push(x->right);
        }
        return order;
    }
};
class Node:
    """One BST node: a key-value pair, two child links and a subtree node count."""

    def __init__(self, key, val, size):
        self.key = key
        self.val = val
        self.left = None
        self.right = None
        self.size = size  # number of nodes in the subtree rooted here


class BST:
    """Ordered symbol table implemented as an unbalanced binary search tree.

    Keys are compared with ``<``, ``>`` and ``==``. ``None`` is rejected as a
    key, and storing ``None`` as a value deletes the key. Every mutating public
    method ends with ``assert self._check()``. That call re-validates the whole
    tree and is skipped when Python runs with ``-O``.
    """

    def __init__(self):
        self.root = None

    def isEmpty(self):
        return self.size() == 0

    def size(self):
        """Number of key-value pairs in the table."""
        return self._size(self.root)

    def _size(self, x):
        if x is None:
            return 0
        else:
            return x.size

    def contains(self, key):
        """True iff ``key`` is present (values are never None, see ``put``)."""
        if key is None:
            raise ValueError("argument to contains() is None")
        return self.get(key) is not None

    def get(self, key):
        """Value for ``key``, or None if absent."""
        return self._get(self.root, key)

    def _get(self, x, key):
        if key is None:
            raise ValueError("calls get() with a None key")
        if x is None:
            return None
        if key < x.key:
            return self._get(x.left, key)
        elif key > x.key:
            return self._get(x.right, key)
        else:
            return x.val

    def put(self, key, val):
        """Insert or overwrite ``key``; a None value deletes the key instead."""
        if key is None:
            raise ValueError("calls put() with a None key")
        if val is None:
            self.delete(key)
            return
        self.root = self._put(self.root, key, val)
        assert self._check()

    def _put(self, x, key, val):
        if x is None:
            return Node(key, val, 1)
        if key < x.key:
            x.left = self._put(x.left, key, val)
        elif key > x.key:
            x.right = self._put(x.right, key, val)
        else:
            x.val = val
        x.size = 1 + self._size(x.left) + self._size(x.right)
        return x

    def deleteMin(self):
        """Remove the smallest key; raises IndexError on an empty table."""
        if self.isEmpty():
            raise IndexError("Symbol table underflow")
        self.root = self._deleteMin(self.root)
        assert self._check()

    def _deleteMin(self, x):
        if x.left is None:
            return x.right
        x.left = self._deleteMin(x.left)
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    def deleteMax(self):
        """Remove the largest key; raises IndexError on an empty table."""
        if self.isEmpty():
            raise IndexError("Symbol table underflow")
        self.root = self._deleteMax(self.root)
        assert self._check()

    def _deleteMax(self, x):
        if x.right is None:
            return x.left
        x.right = self._deleteMax(x.right)
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    def delete(self, key):
        """Remove ``key`` if present (Hibbard deletion); missing keys are ignored."""
        if key is None:
            raise ValueError("calls delete() with a None key")
        self.root = self._delete(self.root, key)
        assert self._check()

    def _delete(self, x, key):
        if x is None:
            return None
        if key < x.key:
            x.left = self._delete(x.left, key)
        elif key > x.key:
            x.right = self._delete(x.right, key)
        else:
            if x.right is None:
                return x.left
            if x.left is None:
                return x.right
            t = x
            # The successor *node* comes from _min(); the public min() takes no
            # argument and returns a key, so calling it here raised TypeError.
            x = self._min(t.right)
            x.right = self._deleteMin(t.right)
            x.left = t.left
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    def min(self):
        """Smallest key; raises IndexError on an empty table."""
        if self.isEmpty():
            raise IndexError("calls min() with empty symbol table")
        return self._min(self.root).key

    def _min(self, x):
        if x.left is None:
            return x
        else:
            return self._min(x.left)

    def max(self):
        """Largest key; raises IndexError on an empty table."""
        if self.isEmpty():
            raise IndexError("calls max() with empty symbol table")
        return self._max(self.root).key

    def _max(self, x):
        if x.right is None:
            return x
        else:
            return self._max(x.right)

    def floor(self, key):
        """Largest key <= ``key``; raises IndexError if empty or none exists."""
        if key is None:
            raise ValueError("argument to floor() is None")
        if self.isEmpty():
            raise IndexError("calls floor() with empty symbol table")
        x = self._floor(self.root, key)
        if x is None:
            raise IndexError("argument to floor() is too small")
        else:
            return x.key

    def _floor(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key < x.key:
            return self._floor(x.left, key)
        t = self._floor(x.right, key)
        if t is not None:
            return t
        else:
            return x

    def floor2(self, key):
        """Alternative floor() that tracks the best candidate key on the way down."""
        x = self._floor2(self.root, key, None)
        if x is None:
            raise IndexError("argument to floor() is too small")
        else:
            return x

    def _floor2(self, x, key, best):
        if x is None:
            return best
        if key < x.key:
            return self._floor2(x.left, key, best)
        elif key > x.key:
            return self._floor2(x.right, key, x.key)
        else:
            return x.key

    def ceiling(self, key):
        """Smallest key >= ``key``; raises IndexError if empty or none exists."""
        if key is None:
            raise ValueError("argument to ceiling() is None")
        if self.isEmpty():
            raise IndexError("calls ceiling() with empty symbol table")
        x = self._ceiling(self.root, key)
        if x is None:
            raise IndexError("argument to ceiling() is too large")
        else:
            return x.key

    def _ceiling(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key < x.key:
            t = self._ceiling(x.left, key)
            if t is not None:
                return t
            else:
                return x
        return self._ceiling(x.right, key)

    def select(self, rank):
        """Key with 0-based ``rank``; raises ValueError if out of range."""
        if rank < 0 or rank >= self.size():
            raise ValueError("argument to select() is invalid: " + str(rank))
        return self._select(self.root, rank)

    def _select(self, x, rank):
        if x is None:
            return None
        leftSize = self._size(x.left)
        if leftSize > rank:
            return self._select(x.left, rank)
        elif leftSize < rank:
            return self._select(x.right, rank - leftSize - 1)
        else:
            return x.key

    def rank(self, key):
        """Number of keys strictly less than ``key``."""
        if key is None:
            raise ValueError("argument to rank() is None")
        return self._rank(key, self.root)

    def _rank(self, key, x):
        if x is None:
            return 0
        if key < x.key:
            return self._rank(key, x.left)
        elif key > x.key:
            return 1 + self._size(x.left) + self._rank(key, x.right)
        else:
            return self._size(x.left)

    def keys(self):
        """All keys in ascending order (a list)."""
        if self.isEmpty():
            return []
        return self.keysInRange(self.min(), self.max())

    def keysInRange(self, lo, hi):
        """Keys in [lo, hi] in ascending order (a list)."""
        if lo is None:
            raise ValueError("first argument to keys() is None")
        if hi is None:
            raise ValueError("second argument to keys() is None")
        queue = []
        self._keys(self.root, queue, lo, hi)
        return queue

    def _keys(self, x, queue, lo, hi):
        if x is None:
            return
        if lo < x.key:
            self._keys(x.left, queue, lo, hi)
        if lo <= x.key <= hi:
            queue.append(x.key)
        if hi > x.key:
            self._keys(x.right, queue, lo, hi)

    def sizeInRange(self, lo, hi):
        """Number of keys in [lo, hi]."""
        if lo is None:
            raise ValueError("first argument to size() is None")
        if hi is None:
            raise ValueError("second argument to size() is None")
        if lo > hi:
            return 0
        if self.contains(hi):
            return self.rank(hi) - self.rank(lo) + 1
        else:
            return self.rank(hi) - self.rank(lo)

    def height(self):
        """Height of the tree; -1 when empty, 0 for a single node."""
        return self._height(self.root)

    def _height(self, x):
        if x is None:
            return -1
        return 1 + max(self._height(x.left), self._height(x.right))

    def levelOrder(self):
        """Keys in breadth-first order: root first, left child before right."""
        from collections import deque

        keys = []
        queue = deque([self.root])
        while queue:
            x = queue.popleft()  # O(1), unlike list.pop(0)
            if x is None:
                continue
            keys.append(x.key)
            queue.append(x.left)
            queue.append(x.right)
        return keys

    def _check(self):
        """Validate all invariants, printing which ones fail."""
        if not self._isBST():
            print("Not in symmetric order")
        if not self._isSizeConsistent():
            print("Subtree counts not consistent")
        if not self._isRankConsistent():
            print("Ranks not consistent")
        return self._isBST() and self._isSizeConsistent() and self._isRankConsistent()

    def _isBST(self):
        return self._isBSTHelper(self.root, None, None)

    def _isBSTHelper(self, x, minKey, maxKey):
        # Every key must lie strictly between the bounds inherited from its ancestors.
        if x is None:
            return True
        if minKey is not None and x.key <= minKey:
            return False
        if maxKey is not None and x.key >= maxKey:
            return False
        return self._isBSTHelper(x.left, minKey, x.key) and self._isBSTHelper(x.right, x.key, maxKey)

    def _isSizeConsistent(self):
        return self._isSizeConsistentHelper(self.root)

    def _isSizeConsistentHelper(self, x):
        if x is None:
            return True
        if x.size != self._size(x.left) + self._size(x.right) + 1:
            return False
        return self._isSizeConsistentHelper(x.left) and self._isSizeConsistentHelper(x.right)

    def _isRankConsistent(self):
        # select() and rank() must be inverses of each other.
        for i in range(self.size()):
            if i != self.rank(self.select(i)):
                return False
        for key in self.keys():
            if key != self.select(self.rank(key)):
                return False
        return True

9.4 Traversal

To traverse binary trees with depth-first search, execute the following three operations in a certain order:

  • N: Visit the current node.

  • L: Recursively traverse the current node's left subtree.

  • R: Recursively traverse the current node's right subtree.

Three types of traversal

  1. Pre-order => NLR

  2. Post-order => LRN

  3. In-order => LNR

Traversal

Depth-first traversal (dotted path) of a binary tree:

  1. Pre-order (node visited at position red):

    F, B, A, D, C, E, G, I, H;

  2. In-order (node visited at position green):

    A, B, C, D, E, F, G, H, I;

  3. Post-order (node visited at position blue):

    A, C, E, D, B, H, I, G, F.

Level order (breadth-first traversal): visit all the nodes of a tree level by level, starting at the root and going left to right within each level.

Level Order Traversal

  1. Start at the root node.

  2. Visit all the nodes at the current level.

  3. Move to the next level and repeat steps 2–3 until every level of the tree has been visited. (In code, a FIFO queue does this: dequeue a node, visit it, then enqueue its children.)

10 Balanced Search Trees

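Symbol-table implementations compared ($N$ = number of keys; costs are numbers of key compares):

| Implementation | Worst-case search | Worst-case insert | Worst-case delete | Average search hit | Average insert | Average delete | Ordered iteration? | Key interface |
|---|---|---|---|---|---|---|---|---|
| Sequential search (unordered list) | $N$ | $N$ | $N$ | $N/2$ | $N$ | $N/2$ | no | `equals()` |
| Binary search (ordered array) | $\lg N$ | $N$ | $N$ | $\lg N$ | $N/2$ | $N/2$ | yes | `compareTo()` |
| BST | $N$ | $N$ | $N$ | $1.39 \lg N$ | $1.39 \lg N$ | ? | yes | `compareTo()` |
| 2-3 Tree | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | yes | `compareTo()` |
| Red-Black BST | $2 \lg N$ | $2 \lg N$ | $2 \lg N$ | $1.0 \lg N$ | $1.0 \lg N$ | $1.0 \lg N$ | yes | `compareTo()` |

Worst-case costs are measured after $N$ inserts; average-case costs are measured after $N$ random inserts. The BST average delete cost is marked "?" because Hibbard deletion does not preserve the random shape of the tree. For the red-black BST, the average-case coefficient is an empirical observation (extremely close to 1), not a proven constant.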

10.1 2-3 Trees

2-3 Tree

  • Allow 1 or 2 keys per node.

  • 2-node: one key, two children.

  • 3-node: two keys, three children.

2-3 Tree

Searching in 2-3 Tree

  1. Compare search key against keys in node.

  2. Find interval containing search key.

  3. Follow the associated link (recursively).

Inserting into a 2-node At Bottom

  1. Search for key, as usual.

  2. Replace 2-node with 3-node.

Inserting into a 3-node At Bottom

  1. Add new key to 3-node to create a temporary 4-node.

  2. Move middle key in 4-node into a parent.

  3. Repeat up the tree, as necessary.

  4. If you reach the root and it's a 4-node, split it into three 2-nodes.

Properties

  • Maintain symmetric order and perfect balance: Every path from root to null link has same length.

    Proof

  • Tree height, worst case: $\lg N$ => all 2-nodes.

  • Tree height, best case: $\log_3 N \approx 0.631 \lg N$ => all 3-nodes.

  • Hence the height is between 12 and 20 for a million nodes.

  • And between 18 and 30 for a billion nodes.

  • Guaranteed logarithmic performance for search and insert.

10.2 Red-Black BSTs

10.2.1 Left-Leaning Red-Black BSTs

  1. Definition 1:

    • Represent 2–3 tree as a BST.

    • Use "internal" left-leaning links as "glue" for 3–nodes.

  2. Definition 2: A BST such that:

    • No node has two red links connected to it.

    • Every path from root to null link has the same number of black links.

    • Red links lean left.

Red-Black BST

10.2.2 Elementary Red-Black BST Operations

  1. Left rotation: Orient a (temporarily) right-leaning red link to lean left.

    Left Rotation
  2. Right rotation: Orient a left-leaning red link to (temporarily) lean right.

    Right Rotation
  3. Color flip: Recolor to split a (temporary) 4-node.

    Color Flip

10.2.3 Red-Black BST Operations

Case 1: Insert into a 2-node at the bottom | Insert into a tree with exactly 1 node

  1. Do standard BST insert; color new link red.

  2. If new red link is a right link, rotate left.

Case 2: Insert into a 3-node at the bottom | Insert into a tree with exactly 2 nodes.

  1. Do standard BST insert; color new link red.

  2. Rotate to balance the 4-node (if needed).

  3. Flip colors to pass red link up one level.

  4. Rotate to make the red link lean left (if needed).

  5. Repeat case 1 or case 2 up the tree (if needed).

Insert into a 3-node at the bottom

Insertion for Red-Black BSTs

  • Right child red, left child black: rotate left.

  • Left child and left-left grandchild both red: rotate right.

  • Both children red: flip colors.

10.2.4 Red-Black BST Implementations

import java.util.LinkedList;
import java.util.NoSuchElementException;
import java.util.Queue;

/**
 * Ordered symbol table implemented as a left-leaning red-black BST.
 *
 * <p>Each node stores the color of the link from its parent ({@code RED} or
 * {@code BLACK}) and the number of nodes in its subtree. Null keys are rejected
 * with {@code IllegalArgumentException}; {@code put(key, null)} deletes the key,
 * so stored values are never null.
 */
public class RedBlackBST<Key extends Comparable<Key>, Value> {
    private static final boolean RED = true;
    private static final boolean BLACK = false;

    private Node root;

    // Key-value pair, child links, color of the link from the parent, subtree count.
    private class Node {
        private Key key;
        private Value val;
        private Node left, right;
        private boolean color;
        private int size;

        public Node(Key key, Value val, boolean color, int size) {
            this.key = key;
            this.val = val;
            this.color=color;
            this.size = size;
        }
    }

    /** Creates an empty symbol table. */
    public RedBlackBST() {
    }

    // A null link counts as black.
    private boolean isRed(Node x) {
        if (x == null) return false;
        return x.color == RED;
    }

    // Subtree size; 0 for a null link.
    private int size(Node x) {
        if (x == null) return 0;
        return x.size;
    }

    /** Returns the number of key-value pairs. */
    public int size() {
        return size(root);
    }

    /** Returns true if the table has no keys. */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Returns the value paired with {@code key}, or null if absent.
     * @throws IllegalArgumentException if {@code key} is null
     */
    public Value get(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to get() is null");
        return get(root, key);
    }

    // Iterative search starting at x.
    private Value get(Node x, Key key) {
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp < 0) x = x.left;
            else if (cmp > 0) x = x.right;
            else return x.val;
        }
        return null;
    }

    // Sound because put() never stores a null value.
    public boolean contains(Key key) {
        return get(key) != null;
    }

    /**
     * Inserts the pair, overwriting any existing value; a null value deletes the key.
     * @throws IllegalArgumentException if {@code key} is null
     */
    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("first argument to put() is null");
        if (val == null) {
            delete(key);
            return;
        }
        root = put(root, key, val);
        root.color=BLACK;  // the root is always recolored black after an insert
    }

    // Standard BST insert that attaches the new node with a red link, then
    // restores the LLRB shape on the way back up.
    private Node put(Node h, Key key, Value val) {
        if (h == null) return new Node(key, val, RED, 1);
        int cmp = key.compareTo(h.key);
        if (cmp < 0) h.left = put(h.left, key, val);
        else if (cmp > 0) h.right = put(h.right, key, val);
        else h.val = val;
        // Right child red, left child black: lean the red link left.
        if (isRed(h.right) && !isRed(h.left)) h = rotateLeft(h);
        // Left child and left-left grandchild red: rotate right.
        if (isRed(h.left) && isRed(h.left.left)) h = rotateRight(h);
        // Both children red: flip colors to pass the red link up one level.
        if (isRed(h.left) && isRed(h.right)) flipColors(h);
        h.size = size(h.left) + size(h.right) + 1;
        return h;
    }

    /**
     * Removes the smallest key.
     * @throws NoSuchElementException if the table is empty
     */
    public void deleteMin() {
        if (isEmpty()) throw new NoSuchElementException("BST underflow");
        // If both children of the root are black, temporarily make the root red.
        if (!isRed(root.left) && !isRed(root.right)) root.color=RED;
        root = deleteMin(root);
        if (!isEmpty()) root.color=BLACK;
    }

    // Deletes the minimum of the subtree rooted at h. Returning null discards h,
    // so h is expected to have no right child at this point.
    private Node deleteMin(Node h) {
        if (h.left == null) return null;
        // h.left and h.left.left both black: move a red link into the left side before descending.
        if (!isRed(h.left) && !isRed(h.left.left)) h = moveRedLeft(h);
        h.left = deleteMin(h.left);
        return balance(h);
    }

    /**
     * Removes the largest key.
     * @throws NoSuchElementException if the table is empty
     */
    public void deleteMax() {
        if (isEmpty()) throw new NoSuchElementException("BST underflow");
        // If both children of the root are black, temporarily make the root red.
        if (!isRed(root.left) && !isRed(root.right)) root.color=RED;
        root = deleteMax(root);
        if (!isEmpty()) root.color=BLACK;
    }

    // Deletes the maximum of the subtree rooted at h.
    private Node deleteMax(Node h) {
        // Lean a left red link right so the maximum can be reached along the right spine.
        if (isRed(h.left)) h = rotateRight(h);
        if (h.right == null) return null;
        // h.right and h.right.left both black: move a red link into the right side before descending.
        if (!isRed(h.right) && !isRed(h.right.left)) h = moveRedRight(h);
        h.right = deleteMax(h.right);
        return balance(h);
    }

    /**
     * Removes {@code key} if present; a missing key is a no-op.
     * @throws IllegalArgumentException if {@code key} is null
     */
    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to delete() is null");
        // The recursive delete assumes the key exists.
        if (!contains(key)) return;
        // If both children of the root are black, temporarily make the root red.
        if (!isRed(root.left) && !isRed(root.right)) root.color=RED;
        root = delete(root, key);
        if (!isEmpty()) root.color=BLACK;
    }

    // Deletes key from the subtree rooted at h (key must be present).
    private Node delete(Node h, Key key) {
        if (key.compareTo(h.key) < 0) {
            // Going left: same preparation as deleteMin().
            if (!isRed(h.left) && !isRed(h.left.left)) h = moveRedLeft(h);
            h.left = delete(h.left, key);
        } else {
            // Going right (or deleting h): same preparation as deleteMax().
            if (isRed(h.left)) h = rotateRight(h);
            if (key.compareTo(h.key) == 0 && (h.right == null)) return null;
            if (!isRed(h.right) && !isRed(h.right.left)) h = moveRedRight(h);
            if (key.compareTo(h.key) == 0) {
                // Replace h's entry with its successor's, then delete the successor.
                Node x = min(h.right);
                h.key = x.key;
                h.val = x.val;
                h.right = deleteMin(h.right);
            }
            else h.right = delete(h.right, key);
        }
        return balance(h);
    }

    // Turns a left-leaning red link into a right-leaning one; x takes h's place and color.
    private Node rotateRight(Node h) {
        assert (h != null) && isRed(h.left);
        Node x = h.left;
        h.left = x.right;
        x.right = h;
        x.color=h.color;
        h.color=RED;
        x.size = h.size;
        h.size = size(h.left) + size(h.right) + 1;
        return x;
    }

    // Turns a right-leaning red link into a left-leaning one; x takes h's place and color.
    private Node rotateLeft(Node h) {
        assert (h != null) && isRed(h.right);
        Node x = h.right;
        h.right = x.left;
        x.left = h;
        x.color=h.color;
        h.color=RED;
        x.size = h.size;
        h.size = size(h.left) + size(h.right) + 1;
        return x;
    }

    // Inverts the colors of h and both of its children (both children must be non-null).
    private void flipColors(Node h) {
        h.color=!h.color;
        h.left.color=!h.left.color;
        h.right.color=!h.right.color;
    }

    // Makes h.left or one of its children red, borrowing from the right sibling
    // when h.right.left is red.
    private Node moveRedLeft(Node h) {
        flipColors(h);
        if (isRed(h.right.left)) {
            h.right = rotateRight(h.right);
            h = rotateLeft(h);
            flipColors(h);
        }
        return h;
    }

    // Makes h.right or one of its children red, borrowing from the left sibling
    // when h.left.left is red.
    private Node moveRedRight(Node h) {
        flipColors(h);
        if (isRed(h.left.left)) {
            h = rotateRight(h);
            flipColors(h);
        }
        return h;
    }

    // Restores the LLRB shape at h on the way up from a deletion and refreshes its size.
    private Node balance(Node h) {
        if (isRed(h.right) && !isRed(h.left)) h = rotateLeft(h);
        if (isRed(h.left) && isRed(h.left.left)) h = rotateRight(h);
        if (isRed(h.left) && isRed(h.right)) flipColors(h);
        h.size = size(h.left) + size(h.right) + 1;
        return h;
    }

    /** Returns the height of the tree (-1 if empty, 0 for a single node). */
    public int height() {
        return height(root);
    }

    private int height(Node x) {
        if (x == null) return -1;
        return 1 + Math.max(height(x.left), height(x.right));
    }

    /**
     * Returns the smallest key.
     * @throws NoSuchElementException if the table is empty
     */
    public Key min() {
        if (isEmpty()) throw new NoSuchElementException("calls min() with empty symbol table");
        return min(root).key;
    }

    // Leftmost node of a non-empty subtree.
    private Node min(Node x) {
        if (x.left == null) return x;
        else return min(x.left);
    }

    /**
     * Returns the largest key.
     * @throws NoSuchElementException if the table is empty
     */
    public Key max() {
        if (isEmpty()) throw new NoSuchElementException("calls max() with empty symbol table");
        return max(root).key;
    }

    // Rightmost node of a non-empty subtree.
    private Node max(Node x) {
        if (x.right == null) return x;
        else return max(x.right);
    }

    /**
     * Returns the largest key less than or equal to {@code key}.
     * @throws IllegalArgumentException if {@code key} is null
     * @throws NoSuchElementException if the table is empty or no such key exists
     */
    public Key floor(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        if (isEmpty()) throw new NoSuchElementException("calls floor() with empty symbol table");
        Node x = floor(root, key);
        if (x == null) throw new NoSuchElementException("argument to floor() is too small");
        else return x.key;
    }

    private Node floor(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) return floor(x.left, key);
        // x is a candidate; a closer one may exist in the right subtree.
        Node t = floor(x.right, key);
        if (t != null) return t;
        else return x;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}.
     * @throws IllegalArgumentException if {@code key} is null
     * @throws NoSuchElementException if the table is empty or no such key exists
     */
    public Key ceiling(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to ceiling() is null");
        if (isEmpty()) throw new NoSuchElementException("calls ceiling() with empty symbol table");
        Node x = ceiling(root, key);
        if (x == null) throw new NoSuchElementException("argument to ceiling() is too large");
        else return x.key;
    }

    private Node ceiling(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp > 0) return ceiling(x.right, key);
        // x is a candidate; a closer one may exist in the left subtree.
        Node t = ceiling(x.left, key);
        if (t != null) return t;
        else return x;
    }

    /**
     * Returns the key of 0-based rank {@code rank}.
     * @throws IllegalArgumentException unless 0 <= rank < size()
     */
    public Key select(int rank) {
        if (rank < 0 || rank >= size()) {
            throw new IllegalArgumentException("argument to select() is invalid: " + rank);
        }
        return select(root, rank);
    }

    private Key select(Node x, int rank) {
        if (x == null) return null;
        int leftSize = size(x.left);
        if (leftSize > rank) return select(x.left, rank);
        else if (leftSize < rank) return select(x.right, rank - leftSize - 1);
        else return x.key;
    }

    /**
     * Returns the number of keys strictly less than {@code key}.
     * @throws IllegalArgumentException if {@code key} is null
     */
    public int rank(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to rank() is null");
        return rank(key, root);
    }

    private int rank(Key key, Node x) {
        if (x == null) return 0;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return rank(key, x.left);
        else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
        else return size(x.left);
    }

    /** Returns all keys in ascending order. */
    public Iterable<Key> keys() {
        if (isEmpty()) return new LinkedList<>();
        return keys(min(), max());
    }

    /**
     * Returns the keys in [lo, hi] in ascending order.
     * @throws IllegalArgumentException if either bound is null
     */
    public Iterable<Key> keys(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to keys() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to keys() is null");
        Queue<Key> queue = new LinkedList<>();
        keys(root, queue, lo, hi);
        return queue;
    }

    // In-order walk that only enters subtrees that can contain keys in [lo, hi].
    private void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
        if (x == null) return;
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        if (cmplo < 0) keys(x.left, queue, lo, hi);
        if (cmplo <= 0 && cmphi >= 0) queue.add(x.key);
        if (cmphi > 0) keys(x.right, queue, lo, hi);
    }

    /**
     * Returns the number of keys in [lo, hi].
     * @throws IllegalArgumentException if either bound is null
     */
    public int size(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to size() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to size() is null");
        if (lo.compareTo(hi) > 0) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else return rank(hi) - rank(lo);
    }

    // Debug self-check of every invariant, printing the ones that fail.
    // NOTE(review): not called anywhere in this class; presumably meant for `assert check();`.
    private boolean check() {
        if (!isBST()) System.out.println("Not in symmetric order");
        if (!isSizeConsistent()) System.out.println("Subtree counts not consistent");
        if (!isRankConsistent()) System.out.println("Ranks not consistent");
        if (!is23()) System.out.println("Not a 2-3 tree");
        if (!isBalanced()) System.out.println("Not balanced");
        return isBST() && isSizeConsistent() && isRankConsistent() && is23() && isBalanced();
    }

    private boolean isBST() {
        return isBST(root, null, null);
    }

    // Every key lies strictly between the bounds inherited from its ancestors (null = unbounded).
    private boolean isBST(Node x, Key min, Key max) {
        if (x == null) return true;
        if (min != null && x.key.compareTo(min) <= 0) return false;
        if (max != null && x.key.compareTo(max) >= 0) return false;
        return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
    }

    private boolean isSizeConsistent() {
        return isSizeConsistent(root);
    }

    private boolean isSizeConsistent(Node x) {
        if (x == null) return true;
        if (x.size != size(x.left) + size(x.right) + 1) return false;
        return isSizeConsistent(x.left) && isSizeConsistent(x.right);
    }

    // select() and rank() must be inverses of each other.
    private boolean isRankConsistent() {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (Key key : keys())
            if (key.compareTo(select(rank(key))) != 0) return false;
        return true;
    }

    private boolean is23() {
        return is23(root);
    }

    // No right-leaning red links and no two consecutive left red links (below the root).
    private boolean is23(Node x) {
        if (x == null) return true;
        if (isRed(x.right)) return false;
        if (x != root && isRed(x) && isRed(x.left))
            return false;
        return is23(x.left) && is23(x.right);
    }

    // Counts black links on the leftmost path, then requires the same count on every root-to-null path.
    private boolean isBalanced() {
        int black = 0;
        Node x = root;
        while (x != null) {
            if (!isRed(x)) black++;
            x = x.left;
        }
        return isBalanced(root, black);
    }

    private boolean isBalanced(Node x, int black) {
        if (x == null) return black == 0;
        if (!isRed(x)) black--;
        return isBalanced(x.left, black) && isBalanced(x.right, black);
    }
}
#ifndef REDBLACKBST_H
#define REDBLACKBST_H

#include <algorithm>
#include <cassert>
#include <iostream>
#include <queue>
#include <stdexcept>
#include <string>

/**
 * Ordered symbol table implemented as a left-leaning red-black BST.
 *
 * Keys are compared with `<`, `>` and `==`. A node's `color` is the color of
 * the link from its parent, and each node caches its subtree size. The tree
 * owns its nodes: a removed node is freed immediately and the remaining ones
 * in the destructor. Copying is disabled because a shallow copy would
 * double-free.
 */
template <typename Key, typename Value>
class RedBlackBST {
private:
    static constexpr bool RED = true;
    static constexpr bool BLACK = false;

    struct Node {
        Key key;
        Value val;
        Node *left, *right;
        bool color;  // color of the link from the parent to this node
        int size;    // number of nodes in the subtree rooted here

        Node(const Key& key, const Value& val, bool color, int size)
            : key(key), val(val), left(nullptr), right(nullptr), color(color), size(size) {}
    };

    Node* root;

    // A null link counts as black.
    static bool isRed(Node* x) {
        if (x == nullptr) return false;
        return x->color == RED;
    }

    // Subtree size; 0 for a null link.
    static int size(Node* x) {
        if (x == nullptr) return 0;
        return x->size;
    }

    // Node holding `key`, or nullptr. Membership is decided by this lookup so
    // that a stored value equal to Value() still counts as present.
    static Node* find(Node* x, const Key& key) {
        while (x != nullptr) {
            if (key < x->key) x = x->left;
            else if (key > x->key) x = x->right;
            else return x;
        }
        return nullptr;
    }

    // Frees every node of the subtree (post-order).
    static void destroy(Node* x) {
        if (x == nullptr) return;
        destroy(x->left);
        destroy(x->right);
        delete x;
    }

    // BST insert with a red link to the new node, then local LLRB fixes on the way up.
    Node* put(Node* h, const Key& key, const Value& val) {
        if (h == nullptr) return new Node(key, val, RED, 1);
        if (key < h->key) h->left = put(h->left, key, val);
        else if (key > h->key) h->right = put(h->right, key, val);
        else h->val = val;
        if (isRed(h->right) && !isRed(h->left)) h = rotateLeft(h);
        if (isRed(h->left) && isRed(h->left->left)) h = rotateRight(h);
        if (isRed(h->left) && isRed(h->right)) flipColors(h);
        h->size = size(h->left) + size(h->right) + 1;
        return h;
    }

    Node* deleteMin(Node* h) {
        if (h->left == nullptr) {
            // h is the minimum. The algorithm relies on h having no children
            // here (the reference version simply drops it), so free it.
            delete h;
            return nullptr;
        }
        if (!isRed(h->left) && !isRed(h->left->left)) h = moveRedLeft(h);
        h->left = deleteMin(h->left);
        return balance(h);
    }

    Node* deleteMax(Node* h) {
        if (isRed(h->left)) h = rotateRight(h);
        if (h->right == nullptr) {
            // h is the maximum and, after the rotation above, a leaf.
            delete h;
            return nullptr;
        }
        if (!isRed(h->right) && !isRed(h->right->left)) h = moveRedRight(h);
        h->right = deleteMax(h->right);
        return balance(h);
    }

    // Deletes `key` (which must be present) from the subtree rooted at h.
    Node* deleteNode(Node* h, const Key& key) {
        if (key < h->key) {
            if (!isRed(h->left) && !isRed(h->left->left)) h = moveRedLeft(h);
            h->left = deleteNode(h->left, key);
        } else {
            if (isRed(h->left)) h = rotateRight(h);
            if (key == h->key && (h->right == nullptr)) {
                delete h;  // matching leaf: unlink and free it
                return nullptr;
            }
            if (!isRed(h->right) && !isRed(h->right->left)) h = moveRedRight(h);
            if (key == h->key) {
                // Copy the successor's entry into h; deleteMin() then frees the successor.
                Node* x = min(h->right);
                h->key = x->key;
                h->val = x->val;
                h->right = deleteMin(h->right);
            } else {
                h->right = deleteNode(h->right, key);
            }
        }
        return balance(h);
    }

    Node* rotateRight(Node* h) {
        assert(h != nullptr && isRed(h->left));
        Node* x = h->left;
        h->left = x->right;
        x->right = h;
        x->color = h->color;
        h->color = RED;
        x->size = h->size;
        h->size = size(h->left) + size(h->right) + 1;
        return x;
    }

    Node* rotateLeft(Node* h) {
        assert(h != nullptr && isRed(h->right));
        Node* x = h->right;
        h->right = x->left;
        x->left = h;
        x->color = h->color;
        h->color = RED;
        x->size = h->size;
        h->size = size(h->left) + size(h->right) + 1;
        return x;
    }

    // Inverts the colors of h and both children (both must be non-null).
    static void flipColors(Node* h) {
        h->color = !h->color;
        h->left->color = !h->left->color;
        h->right->color = !h->right->color;
    }

    Node* moveRedLeft(Node* h) {
        flipColors(h);
        if (isRed(h->right->left)) {
            h->right = rotateRight(h->right);
            h = rotateLeft(h);
            flipColors(h);
        }
        return h;
    }

    Node* moveRedRight(Node* h) {
        flipColors(h);
        if (isRed(h->left->left)) {
            h = rotateRight(h);
            flipColors(h);
        }
        return h;
    }

    // Restores the LLRB shape at h after a deletion and refreshes its size.
    Node* balance(Node* h) {
        if (isRed(h->right) && !isRed(h->left)) h = rotateLeft(h);
        if (isRed(h->left) && isRed(h->left->left)) h = rotateRight(h);
        if (isRed(h->left) && isRed(h->right)) flipColors(h);
        h->size = size(h->left) + size(h->right) + 1;
        return h;
    }

    Node* min(Node* x) const {
        if (x->left == nullptr) return x;
        else return min(x->left);
    }

    Node* max(Node* x) const {
        if (x->right == nullptr) return x;
        else return max(x->right);
    }

    Node* floor(Node* x, const Key& key) const {
        if (x == nullptr) return nullptr;
        if (key == x->key) return x;
        if (key < x->key) return floor(x->left, key);
        Node* t = floor(x->right, key);
        if (t != nullptr) return t;
        else return x;
    }

    Node* ceiling(Node* x, const Key& key) const {
        if (x == nullptr) return nullptr;
        if (key == x->key) return x;
        if (key > x->key) return ceiling(x->right, key);
        Node* t = ceiling(x->left, key);
        if (t != nullptr) return t;
        else return x;
    }

    Key select(Node* x, const int rank) const {
        if (x == nullptr) return Key();
        int leftSize = size(x->left);
        if (leftSize > rank) return select(x->left, rank);
        else if (leftSize < rank) return select(x->right, rank - leftSize - 1);
        else return x->key;
    }

    int rank(const Key& key, Node* x) const {
        if (x == nullptr) return 0;
        if (key < x->key) return rank(key, x->left);
        else if (key > x->key) return 1 + size(x->left) + rank(key, x->right);
        else return size(x->left);
    }

    void keys(Node* x, std::queue<Key>& queue, const Key& lo, const Key& hi) const {
        if (x == nullptr) return;
        if (lo < x->key) keys(x->left, queue, lo, hi);
        if (lo <= x->key && x->key <= hi) queue.push(x->key);
        if (hi > x->key) keys(x->right, queue, lo, hi);
    }

    int height(Node* x) const {
        if (x == nullptr) return -1;
        return 1 + std::max(height(x->left), height(x->right));
    }

public:
    RedBlackBST() : root(nullptr) {}

    ~RedBlackBST() { destroy(root); }

    RedBlackBST(const RedBlackBST&) = delete;
    RedBlackBST& operator=(const RedBlackBST&) = delete;

    [[nodiscard]] int size() const { return size(root); }

    [[nodiscard]] bool isEmpty() const { return root == nullptr; }

    /// Value for `key`, or a default-constructed Value if `key` is absent.
    Value get(const Key& key) const {
        const Node* x = find(root, key);
        if (x == nullptr) return Value();
        return x->val;
    }

    /// True iff `key` is stored, whatever its value (even one equal to Value()).
    bool contains(const Key& key) const { return find(root, key) != nullptr; }

    void put(const Key& key, const Value& val) {
        root = put(root, key, val);
        root->color = BLACK;
    }

    /// @throws std::runtime_error if the table is empty.
    void deleteMin() {
        if (isEmpty()) throw std::runtime_error("BST underflow");
        if (!isRed(root->left) && !isRed(root->right)) root->color = RED;
        root = deleteMin(root);
        if (!isEmpty()) root->color = BLACK;
    }

    /// @throws std::runtime_error if the table is empty.
    void deleteMax() {
        if (isEmpty()) throw std::runtime_error("BST underflow");
        if (!isRed(root->left) && !isRed(root->right)) root->color = RED;
        root = deleteMax(root);
        if (!isEmpty()) root->color = BLACK;
    }

    /// Removes `key` if present; a missing key is a no-op.
    void deleteNode(const Key& key) {
        if (!contains(key)) return;
        if (!isRed(root->left) && !isRed(root->right)) root->color = RED;
        root = deleteNode(root, key);
        if (!isEmpty()) root->color = BLACK;
    }

    [[nodiscard]] int height() const { return height(root); }

    Key min() const {
        if (isEmpty()) throw std::runtime_error("calls min() with empty symbol table");
        return min(root)->key;
    }

    Key max() const {
        if (isEmpty()) throw std::runtime_error("calls max() with empty symbol table");
        return max(root)->key;
    }

    Key floor(const Key& key) const {
        if (isEmpty()) throw std::runtime_error("calls floor() with empty symbol table");
        Node* x = floor(root, key);
        if (x == nullptr) throw std::runtime_error("argument to floor() is too small");
        else return x->key;
    }

    Key ceiling(const Key& key) const {
        if (isEmpty()) throw std::runtime_error("calls ceiling() with empty symbol table");
        Node* x = ceiling(root, key);
        if (x == nullptr) throw std::runtime_error("argument to ceiling() is too large");
        else return x->key;
    }

    Key select(int rank) const {
        if (rank < 0 || rank >= size()) {
            throw std::invalid_argument("argument to select() is invalid: " + std::to_string(rank));
        }
        return select(root, rank);
    }

    int rank(const Key& key) const { return rank(key, root); }

    std::queue<Key> keys() const {
        if (isEmpty()) return std::queue<Key>();
        return keys(min(), max());
    }

    std::queue<Key> keys(const Key& lo, const Key& hi) const {
        if (isEmpty() || lo > hi) return std::queue<Key>();
        std::queue<Key> queue;
        keys(root, queue, lo, hi);
        return queue;
    }

    int size(const Key& lo, const Key& hi) const {
        if (lo > hi) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else return rank(hi) - rank(lo);
    }
};

#endif  // REDBLACKBST_H
class Node:
    """One node of RedBlackBST.

    Holds a key/value pair, left/right child links, a color flag and the
    number of nodes in the subtree rooted here (kept up to date by the
    RedBlackBST rebalancing helpers).
    """

    def __init__(self, key, val, color, size):
        self.key = key
        self.val = val
        self.left = None
        self.right = None
        self.color = color  # True for RED, False for BLACK
        self.size = size  # number of nodes in the subtree rooted at this node


class RedBlackBST:
    """Ordered symbol table backed by a left-leaning red-black BST.

    Keys must support <, > and ==. A value of None is indistinguishable from
    a missing key for get() / `in`.
    """

    RED = True
    BLACK = False

    def __init__(self):
        self.root = None

    def is_red(self, x):
        """Return True if node x is red; a missing (None) node counts as black."""
        if x is None:
            return False
        return x.color == RedBlackBST.RED

    def size(self, x):
        """Return the number of nodes in the subtree rooted at x (0 for None)."""
        if x is None:
            return 0
        return x.size

    def __len__(self):
        return self.size(self.root)

    def is_empty(self):
        return self.root is None

    def get(self, key):
        """Return the value stored for key, or None if key is absent (iterative search)."""
        x = self.root
        while x is not None:
            if key < x.key:
                x = x.left
            elif key > x.key:
                x = x.right
            else:
                return x.val
        return None

    def __contains__(self, key):
        # A key explicitly mapped to None is reported as absent.
        return self.get(key) is not None

    def put(self, key, val):
        """Insert key -> val, overwriting the value if key already exists."""
        self.root = self._put(self.root, key, val)
        # The root is always recolored black after an insertion.
        self.root.color=RedBlackBST.BLACK

    def _put(self, h, key, val):
        """Insert into the subtree rooted at h and return its (possibly new) root."""
        if h is None:
            # New nodes are created red with subtree size 1.
            return Node(key, val, RedBlackBST.RED, 1)
        if key < h.key:
            h.left = self._put(h.left, key, val)
        elif key > h.key:
            h.right = self._put(h.right, key, val)
        else:
            h.val = val
        # Restore the left-leaning invariants on the way back up:
        # right-leaning red -> rotate left; two reds in a row on the left ->
        # rotate right; both children red -> flip colors.
        if self.is_red(h.right) and not self.is_red(h.left):
            h = self.rotate_left(h)
        if self.is_red(h.left) and self.is_red(h.left.left):
            h = self.rotate_right(h)
        if self.is_red(h.left) and self.is_red(h.right):
            self.flip_colors(h)
        h.size = self.size(h.left) + self.size(h.right) + 1
        return h

    def delete_min(self):
        """Remove the smallest key. Raises Exception("BST underflow") if empty."""
        if self.is_empty():
            raise Exception("BST underflow")
        # If both children of the root are black, temporarily make the root red
        # so the top-down deletion has a red link to push down.
        if not self.is_red(self.root.left) and not self.is_red(self.root.right):
            self.root.color=RedBlackBST.RED
        self.root = self._delete_min(self.root)
        if not self.is_empty():
            self.root.color=RedBlackBST.BLACK

    def _delete_min(self, h):
        """Delete the minimum of the subtree rooted at h; return the new subtree root."""
        if h.left is None:
            return None
        # Make sure the left path is not all black before descending into it.
        if not self.is_red(h.left) and not self.is_red(h.left.left):
            h = self.move_red_left(h)
        h.left = self._delete_min(h.left)
        return self.balance(h)

    def delete_max(self):
        """Remove the largest key. Raises Exception("BST underflow") if empty."""
        if self.is_empty():
            raise Exception("BST underflow")
        if not self.is_red(self.root.left) and not self.is_red(self.root.right):
            self.root.color=RedBlackBST.RED
        self.root = self._delete_max(self.root)
        if not self.is_empty():
            self.root.color=RedBlackBST.BLACK

    def _delete_max(self, h):
        """Delete the maximum of the subtree rooted at h; return the new subtree root."""
        # Lean a left red link to the right so the maximum can be removed on the right.
        if self.is_red(h.left):
            h = self.rotate_right(h)
        if h.right is None:
            return None
        if not self.is_red(h.right) and not self.is_red(h.right.left):
            h = self.move_red_right(h)
        h.right = self._delete_max(h.right)
        return self.balance(h)

    def delete(self, key):
        """Remove key (and its value) if present; no-op if absent.

        Raises Exception if key is None.
        """
        if key is None:
            raise Exception("argument to delete() is null")
        if not self.__contains__(key):
            return
        if not self.is_red(self.root.left) and not self.is_red(self.root.right):
            self.root.color=RedBlackBST.RED
        self.root = self._delete(self.root, key)
        if not self.is_empty():
            self.root.color=RedBlackBST.BLACK

    def _delete(self, h, key):
        """Delete key from the subtree rooted at h; return the new subtree root.

        Precondition: key is present in this subtree (delete() checks
        __contains__ first), which is why h.left / h.right are dereferenced
        without None checks when descending toward key.
        """
        if key < h.key:
            if not self.is_red(h.left) and not self.is_red(h.left.left):
                h = self.move_red_left(h)
            h.left = self._delete(h.left, key)
        else:
            if self.is_red(h.left):
                h = self.rotate_right(h)
            # Key found at a node with no right child: drop the node.
            if key == h.key and h.right is None:
                return None
            if not self.is_red(h.right) and not self.is_red(h.right.left):
                h = self.move_red_right(h)
            if key == h.key:
                if h.right is not None:
                    # Replace h's entry with its successor, then delete the
                    # successor from the right subtree.
                    x = self._min(h.right)
                    h.key = x.key
                    h.val = x.val
                    h.right = self._delete_min(h.right)
                else:
                    return h.left
            else:
                h.right = self._delete(h.right, key)
        return self.balance(h)

    def rotate_right(self, h):
        """Rotate the red left link of h to the right; return the new subtree root."""
        assert h is not None and self.is_red(h.left)
        x = h.left
        h.left = x.right
        x.right = h
        x.color=h.color
        h.color=RedBlackBST.RED
        x.size = h.size
        h.size = self.size(h.left) + self.size(h.right) + 1
        return x

    def rotate_left(self, h):
        """Rotate the red right link of h to the left; return the new subtree root."""
        assert h is not None and self.is_red(h.right)
        x = h.right
        h.right = x.left
        x.left = h
        x.color=h.color
        h.color=RedBlackBST.RED
        x.size = h.size
        h.size = self.size(h.left) + self.size(h.right) + 1
        return x

    def flip_colors(self, h):
        """Invert the colors of h and both of its children (both must exist)."""
        assert h is not None and h.left is not None and h.right is not None
        h.color=not h.color
        h.left.color=not h.left.color
        h.right.color=not h.right.color

    def move_red_left(self, h):
        """Push a red link toward h's left subtree before descending left in a deletion.

        Callers invoke this when neither h.left nor h.left.left is red.
        """
        self.flip_colors(h)
        if self.is_red(h.right.left):
            # Borrow from the right sibling.
            h.right = self.rotate_right(h.right)
            h = self.rotate_left(h)
            self.flip_colors(h)
        return h

    def move_red_right(self, h):
        """Push a red link toward h's right subtree before descending right in a deletion.

        Callers invoke this when neither h.right nor h.right.left is red.
        """
        self.flip_colors(h)
        if self.is_red(h.left.left):
            h = self.rotate_right(h)
            self.flip_colors(h)
        return h

    def balance(self, h):
        """Restore the left-leaning invariants at h and recompute its subtree size."""
        if self.is_red(h.right) and not self.is_red(h.left):
            h = self.rotate_left(h)
        if self.is_red(h.left) and self.is_red(h.left.left):
            h = self.rotate_right(h)
        if self.is_red(h.left) and self.is_red(h.right):
            self.flip_colors(h)
        h.size = self.size(h.left) + self.size(h.right) + 1
        return h

    def height(self):
        """Return the tree height (-1 for an empty tree, 0 for a single node)."""
        return self._height(self.root)

    def _height(self, x):
        if x is None:
            return -1
        return 1 + max(self._height(x.left), self._height(x.right))

    def min(self):
        """Return the smallest key. Raises Exception if the table is empty."""
        if self.is_empty():
            raise Exception("calls min() with empty symbol table")
        return self._min(self.root).key

    def _min(self, x):
        """Return the node with the smallest key in the (non-empty) subtree x."""
        if x.left is None:
            return x
        else:
            return self._min(x.left)

    def max(self):
        """Return the largest key. Raises Exception if the table is empty."""
        if self.is_empty():
            raise Exception("calls max() with empty symbol table")
        return self._max(self.root).key

    def _max(self, x):
        """Return the node with the largest key in the (non-empty) subtree x."""
        if x.right is None:
            return x
        else:
            return self._max(x.right)

    def floor(self, key):
        """Return the largest key <= key.

        Raises Exception if key is None, the table is empty, or no such key exists.
        """
        if key is None:
            raise Exception("argument to floor() is null")
        if self.is_empty():
            raise Exception("calls floor() with empty symbol table")
        x = self._floor(self.root, key)
        if x is None:
            raise Exception("argument to floor() is too small")
        else:
            return x.key

    def _floor(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key < x.key:
            return self._floor(x.left, key)
        # x.key < key: the floor is x unless a closer key exists in x.right.
        t = self._floor(x.right, key)
        if t is not None:
            return t
        else:
            return x

    def ceiling(self, key):
        """Return the smallest key >= key.

        Raises Exception if key is None, the table is empty, or no such key exists.
        """
        if key is None:
            raise Exception("argument to ceiling() is null")
        if self.is_empty():
            raise Exception("calls ceiling() with empty symbol table")
        x = self._ceiling(self.root, key)
        if x is None:
            raise Exception("argument to ceiling() is too large")
        else:
            return x.key

    def _ceiling(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key > x.key:
            return self._ceiling(x.right, key)
        # key < x.key: the ceiling is x unless a closer key exists in x.left.
        t = self._ceiling(x.left, key)
        if t is not None:
            return t
        else:
            return x

    def select(self, rank):
        """Return the key of the given 0-based rank; raises Exception if out of range."""
        if rank < 0 or rank >= len(self):
            raise Exception("argument to select() is invalid: " + str(rank))
        return self._select(self.root, rank).key

    def _select(self, x, rank):
        if x is None:
            return None
        left_size = self.size(x.left)
        if left_size > rank:
            return self._select(x.left, rank)
        elif left_size < rank:
            # Skip the left subtree and x itself.
            return self._select(x.right, rank - left_size - 1)
        else:
            return x

    def rank(self, key):
        """Return the number of keys strictly less than key."""
        if key is None:
            raise Exception("argument to rank() is null")
        return self._rank(key, self.root)

    def _rank(self, key, x):
        if x is None:
            return 0
        if key < x.key:
            return self._rank(key, x.left)
        elif key > x.key:
            return 1 + self.size(x.left) + self._rank(key, x.right)
        else:
            return self.size(x.left)

    def keys(self):
        """Return all keys as a list in ascending order ([] if empty)."""
        if self.is_empty():
            return []
        return self.keys_in_range(self.min(), self.max())

    def keys_in_range(self, lo, hi):
        """Return the keys k with lo <= k <= hi, as a list in ascending order."""
        if lo is None:
            raise Exception("first argument to keys() is null")
        if hi is None:
            raise Exception("second argument to keys() is null")
        queue = []
        self._keys_in_range(self.root, queue, lo, hi)
        return queue

    def _keys_in_range(self, x, queue, lo, hi):
        # In-order traversal that skips subtrees which cannot intersect [lo, hi].
        if x is None:
            return
        if lo < x.key:
            self._keys_in_range(x.left, queue, lo, hi)
        if lo <= x.key <= hi:
            queue.append(x.key)
        if hi > x.key:
            self._keys_in_range(x.right, queue, lo, hi)

    def size_in_range(self, lo, hi):
        """Return the number of keys k with lo <= k <= hi (0 if lo > hi)."""
        if lo is None:
            raise Exception("first argument to size() is null")
        if hi is None:
            raise Exception("second argument to size() is null")
        if lo > hi:
            return 0
        if self.__contains__(hi):
            return self.rank(hi) - self.rank(lo) + 1
        else:
            return self.rank(hi) - self.rank(lo)

10.2.5 Red-Black BST Properties and Applications

Properties

  1. Height of tree is $\le 2 \lg N$ in the worst case.

    Proof: Every path from root to null link has the same number of black links, and there are never two red links in a row.

  2. Height of tree is $\sim 1.0 \lg N$ in typical applications.

Applications: Red-black trees are widely used as system symbol tables.

  • Java: java.util.TreeMap, java.util.TreeSet

  • C++ STL: map, multimap, multiset

  • Linux kernel: completely fair scheduler, linux/rbtree.h

  • Emacs: conservative stack scanning

10.3 B-Trees

  1. Background Information:

    • Page: Contiguous block of data (e.g., a file or 4,096-byte chunk).

    • Probe: First access to a page (e.g., from disk to memory).

    • Property: Time required for a probe is much higher than time to access data within a page.

    • Goal: Access data using minimum number of probes.

  2. Definition:

    B-tree (Bayer-McCreight, 1972): Generalize 2-3 trees by allowing up to $M - 1$ key-link pairs per node.

    • At least 2 key-link pairs at root.

    • At least $M/2$ key-link pairs in other nodes.

    • External nodes contain client keys.

    • Internal nodes contain copies of keys to guide search.

    B-Tree
  3. Property:

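    A search or an insertion in a B-tree of order $M$ with $N$ keys requires between $\log_{M-1} N$ and $\log_{M/2} N$ probes.

    Proof: All internal nodes (besides root) have between $M/2$ and $M - 1$ links.

    In practice: Number of probes is at most 4 (e.g., $M = 1024$ and $N = 62$ billion give $\log_{M/2} N \le 4$).

    Optimization: Always keep the root page in memory.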

  4. Applications:

    B-trees (and variants B+ Tree, B* Tree, B# Tree) are widely used for file systems and databases.

    • Windows: NTFS.

    • Mac: HFS, HFS+.

    • Linux: ReiserFS, XFS, Ext3FS, JFS.

    • Databases: ORACLE, DB2, INGRES, SQL, PostgreSQL.

Search in B-Tree

  1. Start at root.

  2. Find interval containing search key.

  3. Follow associated link (recursively).

Search in B-Tree

Insert in B-Tree

  1. Search for new key.

  2. Insert at bottom.

  3. Split nodes with $M$ key-link pairs on the way up the tree.

Insert in B-Tree

10.4 AVL Trees

AVL trees maintain height-balance (also called the AVL Property).

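  1. Skew of a node: The height of its right subtree minus that of its left subtree.

    A node is height-balanced if $|\text{skew}| \le 1$.

    Property: A binary tree with height-balanced nodes has height $h = O(\log n)$.

    Proof

  2. Suppose adding or removing a leaf from a height-balanced tree results in imbalance. Heights change by at most 1, so skews still have magnitude $\le 2$. Rebalance the lowest ancestor B that is not height-balanced; without loss of generality skew(B) = 2, and let F be the right child of B.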

    Case 1: skew of F is 0 or Case 2: skew of F is 1

    => Perform a left rotation on B.

    Balancing AVL Trees

    Case 3: skew of F is −1

    Perform a right rotation on F, then a left rotation on B

    Balancing AVL Trees

11 Geometric Applications of BSTs

Topic: Intersections among geometric objects.

Applications: CAD, games, movies, virtual reality, databases...

  • Range search: find all keys between $k_1$ and $k_2$.

  • Range count: number of keys between $k_1$ and $k_2$.

  • Geometric interpretation: Keys are points on a line; find/count points in a given 1d interval.

1d range count

  1. Recursively find all keys in left subtree (if any could fall in range).

  2. Check key in current node.

  3. Recursively find all keys in right subtree (if any could fall in range).

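Property: Running time proportional to $R + \log N$, where $R$ is the number of keys that fall in the range (a range count alone takes time proportional to $\log N$).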

11.2 Line Segment Intersection

Goal: Given $N$ horizontal and vertical line segments, find all intersections (all $x$- and $y$-coordinates are distinct).

Sweep-Line Algorithm => Sweep Vertical Lines from Left to Right

  1. $x$-coordinates define events.

  2. $h$-segment (left endpoint): insert $y$-coordinate into BST.

  3. $h$-segment (right endpoint): remove $y$-coordinate from BST.

  4. $v$-segment: range search for the interval of $y$-endpoints.

Line Segment Intersection

Properties: The sweep-line algorithm takes time proportional to $N \log N + R$ to find all $R$ intersections among $N$ orthogonal line segments.

Proof:

  • Put $x$-coordinates on a PQ (or sort). => $N \log N$

  • Insert $y$-coordinates into BST. => $N \log N$

  • Delete $y$-coordinates from BST. => $N \log N$

  • Range searches in BST. => $N \log N + R$

11.3 Kd-Trees

Goal: 2d orthogonal range search.

Geometric interpretation: Keys are points in the plane; find/count points in a given rectangle.

11.3.1 Grid Implementation

Grid Implementation

  1. Divide space into an $M$-by-$M$ grid of squares.

  2. Create list of points contained in each square.

  3. Use 2d array to directly index relevant square.

  4. Insert: add to list for corresponding square.

  5. Range search: examine only squares that intersect 2d range query.

Properties:

  • Space: $M^2 + N$

  • Time: $1 + N/M^2$ per square examined, on average.

Problems:

  • Clustering: a well-known phenomenon in geometric data.

  • Lists are too long, even though average length is short.

  • Need data structure that adapts gracefully to data.

11.3.2 Space-Partitioning Trees

Space-Partitioning Trees: Use a tree to represent a recursive subdivision of a 2d space.

2d Trees: Recursively divide space into two halfplanes.

Applications: Ray tracing, 2d range search, Flight simulators, N-body simulation, Nearest neighbor search, Accelerate rendering in Doom, etc.

Part 1 2d Trees

Data Structure: BST, but alternate using $x$- and $y$-coordinates as key.

  • Search gives rectangle containing point.

  • Insert further subdivides the plane.

2d tree implementation

Range Search - Find all points in a query axis-aligned rectangle

  1. Check if point in node lies in given rectangle.

  2. Recursively search left/bottom (if any could fall in rectangle).

  3. Recursively search right/top (if any could fall in rectangle).

Properties

  • Typical case: $R + \log N$

  • Worst case (assuming tree is balanced): $R + \sqrt{N}$

Nearest Neighbor Search - Find closest point to query point

  1. Check distance from point in node to query point.

  2. Recursively search left/bottom (if it could contain a closer point).

  3. Recursively search right/top (if it could contain a closer point).

  4. Organize method so that it begins by searching for query point.

Properties:

  • Typical case: $\log N$

  • Worst case (even if tree is balanced): $N$

Part 2 Kd Trees

Kd Tree: Recursively partition $k$-dimensional space into 2 halfspaces.

Implementation: BST, but cycle through dimensions à la 2d trees.

Part 3 N-body Simulation

Goal: Simulate the motion of $N$ particles, mutually affected by gravity.

Appel's Algorithm for N-body Simulation

  1. Build 3d-tree with particles as nodes.

  2. Store center-of-mass of subtree in each node.

  3. To compute total force acting on a particle, traverse tree, but stop as soon as distance from particle to subdivision is sufficiently large.

Properties: Running time per step is $N \log N$.

11.4 Interval Search Tree

Create BST, where each node stores an interval $(lo, hi)$.

  • Use left endpoint as BST key.

  • Store max endpoint in subtree rooted at node.

Insertion for Interval Search Tree

  1. Insert into BST, using $lo$ as the key.

  2. Update max in each node on search path.

Interval Search for Interval Search Tree

  • If interval in node intersects query interval, return it.

  • Else if left subtree is null, go right.

  • Else if max endpoint in left subtree is less than lo, go right.

  • Else go left.

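Order of growth of running time for $N$ intervals:

| Operation | Brute | Interval search tree | Best in theory |
|---|---|---|---|
| Insert interval | $1$ | $\log N$ | $\log N$ |
| Find interval | $N$ | $\log N$ | $\log N$ |
| Delete interval | $N$ | $\log N$ | $\log N$ |
| Find any one interval that intersects $(lo, hi)$ | $N$ | $\log N$ | $\log N$ |
| Find all intervals that intersect $(lo, hi)$ | $N$ | $R \log N$ | $R + \log N$ |

(Interval-search-tree bounds assume a balanced BST, e.g. a red-black BST; $R$ is the number of intervals reported.)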

11.5 Rectangle Intersection

Sweep-line Algorithm

  • $x$-coordinates of left and right endpoints define events.

  • Maintain set of rectangles that intersect the sweep line in an interval search tree (using $y$-intervals of rectangle).

  • Left endpoint: interval search for $y$-interval of rectangle; insert $y$-interval.

  • Right endpoint: remove $y$-interval.

Property: Sweep line algorithm takes time proportional to $N \log N + R \log N$ to find $R$ intersections among a set of $N$ rectangles.

Proof

  • Put $x$-coordinates on a PQ (or sort) => $N \log N$

  • Insert $y$-intervals into ST => $N \log N$

  • Delete $y$-intervals from ST => $N \log N$

  • Interval searches for $y$-intervals => $N \log N + R \log N$

12 Hash Tables

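Symbol-table implementations: worst-case cost (after $N$ inserts) and average case (after $N$ random inserts).

| Implementation | Worst-case search | Worst-case insert | Worst-case delete | Average search hit | Average insert | Average delete | Ordered iteration? | Key interface |
|---|---|---|---|---|---|---|---|---|
| Sequential search (unordered list) | $N$ | $N$ | $N$ | $N/2$ | $N$ | $N/2$ | no | equals() |
| Binary search (ordered array) | $\lg N$ | $N$ | $N$ | $\lg N$ | $N/2$ | $N/2$ | yes | compareTo() |
| BST | $N$ | $N$ | $N$ | $1.39 \lg N$ | $1.39 \lg N$ | ? | yes | compareTo() |
| 2-3 Tree | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | $c \lg N$ | yes | compareTo() |
| Red-Black BST | $2 \lg N$ | $2 \lg N$ | $2 \lg N$ | $1.0 \lg N$ | $1.0 \lg N$ | $1.0 \lg N$ | yes | compareTo() |
| Separate Chaining | $N$ | $N$ | $N$ | 3-5\* | 3-5\* | 3-5\* | no | equals(), hashCode() |
| Linear Probing | $N$ | $N$ | $N$ | 3-5\* | 3-5\* | 3-5\* | no | equals(), hashCode() |

\* Under the uniform hashing assumption.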

12.1 Hash Tables

Definitions

  1. Hashing: Save items in a key-indexed table (index is a function of the key).

  2. Hash function: Method for computing array index from key.

    Issues:

    1. Equality test: Method for checking whether two keys are equal.

    2. Collision resolution: Algorithm and data structure to handle two keys that hash to the same array index.

  3. Hash code: An int between $-2^{31}$ and $2^{31} - 1$.

  4. Hash function: An int between $0$ and $M - 1$ (for use as an array index).

/** Demonstrates the base-31 polynomial string hash. */
public final class StringTest {
    private final char[] s = "Hello, World!".toCharArray();

    /**
     * Computes s[0]*31^(n-1) + s[1]*31^(n-2) + ... + s[n-1] using Horner's rule.
     * Intermediate results overflow and wrap around as Java int arithmetic does,
     * which is the same formula documented for String.hashCode().
     *
     * @return the hash of the characters in s
     */
    public int hash() {
        int h = 0;
        for (char c : s) {
            h = 31 * h + c;
        }
        return h;
    }
}

12.2 Collision Solution Ⅰ - Separate Chaining & Variant

12.2.1 Separate Chaining

  1. Hash: map key to integer $i$ between $0$ and $M - 1$.

  2. Insert: put at front of the $i$-th chain (if not already there).

  3. Search: need to search only the $i$-th chain.

Separate Chaining

Properties

  • Number of probes for search/insert/delete is proportional to $N/M$.

  • Typical choice: $M \sim N/5$ (constant-time operations).

/**
 * Symbol table using separate chaining: an array of M singly linked lists
 * ("chains"); each key lives in the chain selected by its hash.
 * Keys and values are untyped Objects; keys must implement equals() and hashCode()
 * consistently and must not be null.
 */
public class SeparateChainingHashST {
    private static final int DEFAULT_CHAINS = 97;

    private final int M;      // number of chains
    private final Node[] st;  // array of chains; st[i] is the head of chain i (or null)

    private static class Node {
        private final Object key;
        private Object val;
        private final Node next;

        public Node(Object key, Object val, Node node) {
            this.key = key;
            this.val = val;
            this.next = node;
        }
    }

    /** Creates a table with the default number of chains (97). */
    public SeparateChainingHashST() {
        this(DEFAULT_CHAINS);
    }

    /**
     * Creates a table with the given number of chains.
     *
     * @param chains number of chains; should be on the order of N/5 for N keys
     * @throws IllegalArgumentException if chains is not positive
     */
    public SeparateChainingHashST(int chains) {
        if (chains <= 0) {
            throw new IllegalArgumentException("number of chains must be positive: " + chains);
        }
        M = chains;
        st = new Node[M];
    }

    // Maps key to a chain index in [0, M-1]. Masking off the sign bit (instead of
    // Math.abs) keeps the result non-negative even when hashCode() == Integer.MIN_VALUE.
    private int hash(Object key) {
        return (key.hashCode() & 0x7fffffff) % M;
    }

    /** Returns the value for key, or null if key is absent. Only key's chain is searched. */
    public Object get(Object key) {
        int i = hash(key);
        for (Node x = st[i]; x != null; x = x.next)
            if (key.equals(x.key)) return x.val;
        return null;
    }

    /** Associates val with key, overwriting any old value; new keys go at the front of their chain. */
    public void put(Object key, Object val) {
        int i = hash(key);
        for (Node x = st[i]; x != null; x = x.next)
            if (key.equals(x.key)) {
                x.val = val;
                return;
            }
        st[i] = new Node(key, val, st[i]);
    }
}
#include <cstddef>
#include <functional>
#include <list>
#include <optional>
#include <utility>
#include <vector>

// Hash table with separate chaining: each bucket is a linked list of
// (key, value) pairs. Key must be equality-comparable and hashable with
// std::hash<Key>.
template<typename Key, typename Value>
class HashTable {
public:
    // size: number of buckets. A size of 0 is raised to 1 so that
    // hashFunction never divides by zero.
    explicit HashTable(size_t size) : table(size == 0 ? 1 : size) {}

    // Inserts key -> value, or overwrites the value if key is already present.
    void insert(const Key& key, const Value& value) {
        auto& chain = table[hashFunction(key)];
        for (auto& entry : chain) {
            if (entry.first == key) {
                entry.second = value;
                return;
            }
        }
        chain.emplace_back(key, value);
    }

    // Returns the value stored for key, or std::nullopt if key is absent.
    std::optional<Value> get(const Key& key) const {
        const auto& chain = table[hashFunction(key)];
        for (const auto& entry : chain) {
            if (entry.first == key) {
                return entry.second;
            }
        }
        return std::nullopt;
    }

    // Removes key (if present) from its bucket.
    void remove(const Key& key) {
        // Capture by reference and take entries by const& to avoid copying pairs.
        table[hashFunction(key)].remove_if(
            [&key](const std::pair<Key, Value>& entry) { return entry.first == key; });
    }

private:
    std::vector<std::list<std::pair<Key, Value>>> table;

    // Bucket index in [0, table.size()). std::hash works for any hashable key
    // (e.g. std::string), unlike `key % size`, which required integral keys.
    size_t hashFunction(const Key& key) const {
        return std::hash<Key>{}(key) % table.size();
    }
};

12.2.2 Variant - Two-Probe Hashing

  • Hash to two positions, insert key in shorter of the two chains.

  • Reduces expected length of the longest chain to $\log \log N$.

12.3 Collision Solution Ⅱ - Open Addressing

12.3.1 Linear Probing

Open addressing: When a new key collides, find next empty slot, and put it there.

  1. Hash: Map key to integer $i$ between $0$ and $M - 1$.

  2. Search: Search table index $i$; if occupied but no match, try $i + 1$, $i + 2$, etc.

  3. Insert: Put at table index $i$ if free; if not try $i + 1$, $i + 2$, etc.

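Under the uniform hashing assumption, the average number of probes in a linear probing hash table of size $M$ that contains $N = \alpha M$ keys is:

  • Search hit: $\sim \frac{1}{2}\left(1 + \frac{1}{1 - \alpha}\right)$

  • Search miss / insert: $\sim \frac{1}{2}\left(1 + \frac{1}{(1 - \alpha)^2}\right)$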

/**
 * Fixed-capacity symbol table using linear probing (no resizing, no deletion).
 * A null entry in keys[] marks an empty slot. Keys must be non-null.
 */
public class LinearProbingHashST<Key, Value> {
    private static final int DEFAULT_CAPACITY = 30001;

    private final int M;         // number of slots
    private final Key[] keys;    // keys[i] == null means slot i is empty
    private final Value[] vals;  // vals[i] is the value for keys[i]

    /** Creates a table with 30001 slots. */
    public LinearProbingHashST() {
        this(DEFAULT_CAPACITY);
    }

    /**
     * Creates a table with the given number of slots.
     *
     * @throws IllegalArgumentException if capacity is not positive
     */
    @SuppressWarnings("unchecked") // generic array creation via Object[]
    public LinearProbingHashST(int capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive: " + capacity);
        }
        M = capacity;
        keys = (Key[]) new Object[M];
        vals = (Value[]) new Object[M];
    }

    /* Map key to integer i between 0 and M - 1. */
    private int hash(Key key) {
        return (key.hashCode() & 0x7fffffff) % M;
    }

    /*
     * Put at table index i if free; if not try i+1, i+2, etc.
     * Overwrites the value if key is already present. Probing is bounded by M,
     * so a full table raises IllegalStateException instead of looping forever.
     */
    public void put(Key key, Value val) {
        int i = hash(key);
        for (int probes = 0; probes < M; probes++, i = (i + 1) % M) {
            if (keys[i] == null) {
                keys[i] = key;
                vals[i] = val;
                return;
            }
            if (keys[i].equals(key)) {
                vals[i] = val;
                return;
            }
        }
        throw new IllegalStateException("hash table is full");
    }

    /*
     * Search table index i; if occupied but not match, try i+1, i+2, etc.
     * Stops at an empty slot or after M probes (full table), returning null.
     */
    public Value get(Key key) {
        int i = hash(key);
        for (int probes = 0; probes < M && keys[i] != null; probes++, i = (i + 1) % M) {
            if (keys[i].equals(key)) return vals[i];
        }
        return null;
    }
}
/**
 * Symbol table using open addressing with linear probing.
 * The table doubles when it becomes half full and halves when it is
 * 12.5% full or less. A null entry in keys[] marks an empty slot; null keys
 * are rejected, and putting a null value deletes the key.
 */
public class LinearProbingHashST<Key, Value> {
    // must be a power of 2
    private static final int INIT_CAPACITY = 4;

    private int n;           // number of key-value pairs in the symbol table
    private int m;           // size of linear probing table
    private Key[] keys;      // the keys
    private Value[] vals;    // the values

    /** Creates an empty table with INIT_CAPACITY slots. */
    public LinearProbingHashST() {
        this(INIT_CAPACITY);
    }

    /**
     * Creates an empty table with the given number of slots.
     * NOTE(review): hash() masks with (m - 1), so a non-power-of-2 capacity still
     * yields in-range indices but a skewed distribution; capacity 0 makes the
     * first put() fail with an out-of-bounds index.
     */
    public LinearProbingHashST(int capacity) {
        m = capacity;
        n = 0;
        // Unchecked generic array creation (Java cannot create Key[] directly).
        keys = (Key[]) new Object[m];
        vals = (Value[]) new Object[m];
    }

    /** Returns the number of key-value pairs. */
    public int size() {
        return n;
    }

    /** Returns true if the table holds no pairs. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns true if key has a (non-null) value; throws IllegalArgumentException on a null key. */
    public boolean contains(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    // hash function for keys - returns value between 0 and m-1
    // (the textbook modular version; not called anywhere in this class)
    private int hashTextbook(Key key) {
        return (key.hashCode() & 0x7fffffff) % m;
    }

    // hash function for keys - returns value between 0 and m-1 (assumes m is a power of 2)
    // (from Java 7 implementation, protects against poor quality hashCode() implementations)
    private int hash(Key key) {
        int h = key.hashCode();
        // Spread higher bits downward so that masking with (m - 1) uses them.
        h ^= (h >>> 20) ^ (h >>> 12) ^ (h >>> 7) ^ (h >>> 4);
        return h & (m - 1);
    }

    // resizes the hash table to the given capacity by re-hashing all of the keys
    private void resize(int capacity) {
        LinearProbingHashST<Key, Value> temp = new LinearProbingHashST<Key, Value>(capacity);
        for (int i = 0; i < m; i++) {
            if (keys[i] != null) {
                temp.put(keys[i], vals[i]);
            }
        }
        keys = temp.keys;
        vals = temp.vals;
        m = temp.m;
    }

    /**
     * Inserts key -> val, overwriting the old value if key is present.
     * A null val deletes key. Throws IllegalArgumentException on a null key.
     */
    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("first argument to put() is null");
        if (val == null) {
            delete(key);
            return;
        }

        // double table size if 50% full
        if (n >= m / 2) resize(2 * m);

        // Probe from the home slot until the key or an empty slot is found.
        int i;
        for (i = hash(key); keys[i] != null; i = (i + 1) % m) {
            if (keys[i].equals(key)) {
                vals[i] = val;
                return;
            }
        }
        keys[i] = key;
        vals[i] = val;
        n++;
    }

    // Returns the value associated with the specified key.
    // (null if the key is absent; throws IllegalArgumentException on a null key)
    public Value get(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to get() is null");
        for (int i = hash(key); keys[i] != null; i = (i + 1) % m)
            if (keys[i].equals(key))
                return vals[i];
        return null;
    }

    /**
     * Removes key and its value if present (no-op otherwise).
     * Throws IllegalArgumentException on a null key.
     */
    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to delete() is null");
        if (!contains(key)) return;

        // find position i of key
        int i = hash(key);
        while (!key.equals(keys[i])) {
            i = (i + 1) % m;
        }

        // delete key and associated value
        keys[i] = null;
        vals[i] = null;

        // rehash all keys in same cluster
        // (otherwise the new hole would cut off keys that probed past slot i)
        i = (i + 1) % m;
        while (keys[i] != null) {
            // delete keys[i] and vals[i] and reinsert
            Key keyToRehash = keys[i];
            Value valToRehash = vals[i];
            keys[i] = null;
            vals[i] = null;
            // put() increments n again, so decrement first to keep the count unchanged.
            n--;
            put(keyToRehash, valToRehash);
            i = (i + 1) % m;
        }

        // Account for the key actually removed.
        n--;

        // halves size of array if it's 12.5% full or less
        if (n > 0 && n <= m / 8) resize(m / 2);

        assert check();
    }

    // Integrity check used by the assert in delete(): the table is at most half
    // full and every stored key can be found by get().
    private boolean check() {

        // check that hash table is at most 50% full
        if (m < 2 * n) {
            System.err.println("Hash table size m = " + m + "; array size n = " + n);
            return false;
        }

        // check that each key in table can be found by get()
        for (int i = 0; i < m; i++) {
            if (keys[i] == null) continue;
            else if (get(keys[i]) != vals[i]) {
                System.err.println("get[" + keys[i] + "] = " + get(keys[i]) + "; vals[i] = " + vals[i]);
                return false;
            }
        }
        return true;
    }
}
#include <cstddef>
#include <functional>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// One slot of the open-addressing table. key/value are meaningful only while
// `occupied` is true; an unoccupied slot terminates a probe sequence.
template<typename Key, typename Value>
struct HashNode {
    Key key;
    Value value;
    bool occupied;
    HashNode() : key(), value(), occupied(false) {}
    HashNode(Key key, Value value) : key(std::move(key)), value(std::move(value)), occupied(true) {}
};

// Fixed-capacity hash table using open addressing with linear probing.
//
// - insert() overwrites the value of an existing key; inserting a new key into
//   a full table throws std::length_error (instead of looping forever).
// - get()/remove() stop after probing every slot once.
// - remove() re-inserts the rest of the probe cluster so keys stored after the
//   removed one stay reachable.
// Key must be equality-comparable and hashable with std::hash<Key>.
template<typename Key, typename Value>
class HashTable {
private:
    std::vector<HashNode<Key, Value>> table;
    int tableSize;

    // Home slot of key in [0, tableSize). std::hash yields an unsigned value,
    // so negative integer keys can no longer produce a negative index.
    int hashFunction(const Key& key) const {
        return static_cast<int>(std::hash<Key>{}(key) % static_cast<std::size_t>(tableSize));
    }

    // Next slot in the probe sequence, wrapping at the end of the table.
    int nextIndex(int index) const {
        return (index + 1) % tableSize;
    }

    // Re-inserts every entry in the cluster after the (now empty) slot `index`;
    // without this, a key that probed past the removed one would be unreachable.
    void reinsertClusterAfter(int index) {
        for (int i = nextIndex(index); table[i].occupied; i = nextIndex(i)) {
            HashNode<Key, Value> displaced = std::move(table[i]);
            table[i].occupied = false;
            insert(std::move(displaced.key), std::move(displaced.value));
        }
    }

public:
    // size: number of slots; values below 1 are raised to 1 so that hashing
    // never divides by zero.
    HashTable(int size)
        : table(static_cast<std::size_t>(size > 0 ? size : 1)), tableSize(size > 0 ? size : 1) {}

    // Inserts key -> value, or overwrites the value if key is already present.
    // Throws std::length_error if key is new and every slot is occupied.
    void insert(Key key, Value value) {
        int index = hashFunction(key);
        for (int probes = 0; probes < tableSize; ++probes) {
            HashNode<Key, Value>& slot = table[index];
            if (!slot.occupied) {
                slot = HashNode<Key, Value>(std::move(key), std::move(value));
                return;
            }
            if (slot.key == key) {
                slot.value = std::move(value);
                return;
            }
            index = nextIndex(index);
        }
        throw std::length_error("HashTable::insert: table is full");
    }

    // Returns the value stored for key, or std::nullopt if it is absent.
    std::optional<Value> get(const Key& key) const {
        int index = hashFunction(key);
        for (int probes = 0; probes < tableSize && table[index].occupied; ++probes) {
            if (table[index].key == key) {
                return table[index].value;
            }
            index = nextIndex(index);
        }
        return std::nullopt;
    }

    // Removes key if present; otherwise does nothing.
    void remove(const Key& key) {
        int index = hashFunction(key);
        for (int probes = 0; probes < tableSize && table[index].occupied; ++probes) {
            if (table[index].key == key) {
                table[index].occupied = false;
                reinsertClusterAfter(index);
                return;
            }
            index = nextIndex(index);
        }
    }
};

Knuth's Parking Problem

Cars arrive at a one-way street with $M$ parking spaces. Each driver tries to park in their own space $i$: If space $i$ is taken, try $i + 1$, $i + 2$, etc. What is the mean displacement of a car?

  • Half-full: With $M/2$ cars, mean displacement is $\sim 3/2$.

  • Full: With $M$ cars, mean displacement is $\sim \sqrt{\pi M / 8}$.

12.3.2 Variant 1 - Double Hashing

  • Use linear probing, but skip a variable amount, not just 1 each time.

  • Effectively eliminates clustering.

  • Can allow table to become nearly full.

  • More difficult to implement delete.

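Insert: Use the first hash function $h_1$ to calculate the index. If there is a collision, use a second hash value $h_2(k)$ as the "step size" for probing until an empty slot is found (=> probe sequence $(h_1(k) + i \cdot h_2(k)) \bmod M$ for $i = 0, 1, 2, \dots$, with $h_2(k) \neq 0$).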

12.3.3 Variant 2 - Quadratic Probing

Insert: Use the hash function to calculate index. If there is a collision, probe the index using the following probing sequence:

  • index 1: $(h(k) + 1^2) \bmod M$

  • index 2: $(h(k) + 2^2) \bmod M$

  • index 3: $(h(k) + 3^2) \bmod M$

12.3.4 Variant 3 - Cuckoo Hashing

  • Hash key to two positions; insert key into either position; if occupied, reinsert displaced key into its alternative position (and recur).

  • Constant worst case time for search.

12.3.5 Separate Chaining vs. Linear Probing

Separate Chaining

  • Easier to implement delete.

  • Performance degrades gracefully.

  • Clustering less sensitive to poorly-designed hash function.

Linear Probing

  • Less wasted space.

  • Better cache performance.

12.4 Hash Table vs. Balanced Search Tree

Hash Table

  • Simpler to code.

  • No effective alternative for unordered keys.

  • Faster for simple keys (a few arithmetic ops versus compares).

  • Better system support in Java for strings (e.g., cached hash code).

Balanced Search Tree

  • Stronger performance guarantee.

  • Support for ordered ST operations.

  • Easier to implement compareTo() correctly than equals() and hashCode().

Java system includes both.

  • Red-black BSTs: java.util.TreeMap, java.util.TreeSet.

  • Hash tables: java.util.HashMap, java.util.IdentityHashMap.

C++ STL includes both.

  • Red-black BSTs: std::set, std::map.

  • Hash tables: std::unordered_map, std::unordered_set.

13 Symbol Table Applications

13.1 Sets

Mathematical Set: A collection of distinct keys.

13.1.1 Sets in Java

  1. HashSet

    1. Implementation: Uses a hash table (specifically, a HashMap internally) for storage.

    2. Features

      • Efficient for adding, removing, and checking for the existence of elements ($O(1)$ average time complexity).

      • Does not maintain insertion order.

      • Allows a single null element.

  2. LinkedHashSet

    1. Implementation: Extends HashSet and maintains a doubly linked list to preserve the order of element insertion.

    2. Features

      • Elements are iterated in the order they were added.

      • Slightly slower than HashSet due to the linked list overhead.

  3. TreeSet

    1. Implementation: Uses a red-black tree (a self-balancing binary search tree).

    2. Features:

      • Elements are stored in sorted order (natural order or using a Comparator provided during set creation).

      • Provides efficient retrieval of elements in a sorted range.

      • Slower than HashSet and LinkedHashSet for insertion and removal operations (logarithmic time complexity).

      • Does not allow null elements by default.

import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.TreeSet;

/** Compares the iteration order of the three standard Set implementations. */
public class SetExample {
    /** Adds items to target in the given order and returns target. */
    private static Set<String> fill(Set<String> target, String... items) {
        for (String item : items) {
            target.add(item);
        }
        return target;
    }

    public static void main(String[] args) {
        // HashSet - no ordering guarantee; the printed order may vary
        Set<String> hashSet = fill(new HashSet<>(), "Apple", "Banana", "Orange");
        System.out.println("HashSet: " + hashSet);

        // LinkedHashSet - iterates in insertion order: [Apple, Banana, Orange]
        Set<String> linkedHashSet = fill(new LinkedHashSet<>(), "Apple", "Banana", "Orange");
        System.out.println("LinkedHashSet: " + linkedHashSet);

        // TreeSet - iterates in sorted order whatever the insertion order: [Apple, Banana, Orange]
        Set<String> treeSet = fill(new TreeSet<>(), "Orange", "Apple", "Banana");
        System.out.println("TreeSet: " + treeSet);
    }
}

13.1.2 Sets in C++

  1. std::set | std::multiset

    1. Implementation: Usually implemented as a self-balancing binary search tree (often a red-black tree).

    2. Features

      • Elements are stored in sorted order (by default, using std::less, which is the less-than operator <)

      • Most operations like insertion, search, deletion, etc., have a time complexity of $O(\log n)$, where $n$ is the number of elements, making it efficient for larger datasets.

  2. std::unordered_set | std::unordered_multiset

    1. Implementation: Uses a hash table, which prioritizes fast average-case performance for operations like insertion, search, and deletion over maintaining a specific order.

    2. Features

      • Offers O(1) average-case time complexity for insertion, search, and deletion operations.

#include <iostream>
#include <set>

// std::set stores each key once and iterates in ascending order.
int main() {
    const int inputs[] = {3, 1, 4, 1};  // the second 1 is a duplicate and is not added
    std::set<int> uniqueNumbers;
    for (const int value : inputs) {
        uniqueNumbers.insert(value);
    }

    std::cout << "Elements in the set: ";
    for (auto it = uniqueNumbers.cbegin(); it != uniqueNumbers.cend(); ++it) {
        std::cout << *it << " ";
    }
    // Output: 1 3 4
    return 0;
}

13.1.3 Sets in Python

For more information, please visit Sets in Python Programming

13.2 Dictionary Clients

import java.util.HashMap;

/** Basic HashMap operations: put, get, containsKey, remove. */
public class Main {
    public static void main(String[] args) {
        HashMap<String, Integer> ages = new HashMap<>();
        ages.put("Alice", 25);
        ages.put("Bob", 30);
        ages.put("Charlie", 35);

        // The Integer returned by get() is auto-unboxed to int.
        int aliceAge = ages.get("Alice");
        System.out.println("Alice's age: " + aliceAge);

        System.out.println("Is Bob in the map? " + ages.containsKey("Bob"));

        ages.remove("Charlie");
        System.out.println(ages);
    }
}
#include <iostream>
#include <map>
#include <string>

// std::map: an ordered associative container (usually a red-black tree).
int main() {
    std::map<std::string, int> myMap;
    myMap["apple"] = 1;  // operator[] inserts the key if it is missing
    myMap["banana"] = 2;
    myMap["cherry"] = 3;

    // Read with at(): unlike operator[], it never inserts a default value
    // (it throws std::out_of_range if the key is absent).
    std::cout << "The value associated with key 'apple' is: " << myMap.at("apple") << '\n';

    // Iteration visits keys in ascending order.
    for (const auto& [key, value] : myMap) {
        std::cout << "Key: " << key << ", Value: " << value << '\n';
    }
    return 0;
}
#include <iostream>
#include <string>
#include <unordered_map>

// std::unordered_map: a hash-table-based associative container.
int main() {
    std::unordered_map<std::string, int> myMap;
    myMap["apple"] = 1;  // operator[] inserts the key if it is missing
    myMap["banana"] = 2;
    myMap["cherry"] = 3;

    // Read with at(): unlike operator[], it never inserts a default value
    // (it throws std::out_of_range if the key is absent).
    std::cout << "The value associated with key 'apple' is: " << myMap.at("apple") << '\n';

    // Iteration order is unspecified for unordered containers.
    for (const auto& [key, value] : myMap) {
        std::cout << "Key: " << key << ", Value: " << value << '\n';
    }
    return 0;
}
# A dict maps keys to values; look a value up by its key.
person = dict(name="John", age=30, city="New York")
print(person["name"])  # Output: John

For more information about dictionaries in Python, please visit Python Programming.

13.3 Indexing Clients

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.TreeMap;

/** Builds a word -> document-id inverted index (sorted by word) and runs a sample query. */
public class InvertedIndexJava {
    static class Document {
        int id;
        String content;

        Document(int id, String content) {
            this.id = id;
            this.content = content;
        }
    }

    /**
     * Maps each lower-cased, whitespace-separated word to the ids of the documents
     * that contain it. Each id is recorded at most once per word (in document order),
     * even if the word occurs several times in that document.
     */
    static TreeMap<String, List<Integer>> buildInvertedIndex(List<Document> documents) {
        TreeMap<String, List<Integer>> index = new TreeMap<>();
        for (Document doc : documents) {
            // Locale.ROOT makes lower-casing independent of the JVM's default locale.
            String[] words = doc.content.toLowerCase(Locale.ROOT).split("\\s+"); // Tokenize into words
            for (String word : words) {
                if (word.isEmpty()) continue; // split() yields a leading "" when content starts with whitespace
                List<Integer> postings = index.computeIfAbsent(word, k -> new ArrayList<>());
                // A document's words are processed together, so checking the last id is enough.
                if (postings.isEmpty() || postings.get(postings.size() - 1) != doc.id) {
                    postings.add(doc.id);
                }
            }
        }
        return index;
    }

    public static void main(String[] args) {
        List<Document> documents = Arrays.asList(
                new Document(1, "The quick brown fox jumps over the lazy dog"),
                new Document(2, "A lazy cat sleeps all day long"),
                new Document(3, "The quick rabbit jumps over the fence")
        );

        TreeMap<String, List<Integer>> invertedIndex = buildInvertedIndex(documents);

        // Example query: Find documents containing the word "jumps"
        String searchTerm = "jumps";
        if (invertedIndex.containsKey(searchTerm)) {
            System.out.println("Documents containing '" + searchTerm + "': " + invertedIndex.get(searchTerm));
        } else {
            System.out.println("No documents found containing '" + searchTerm + "'");
        }
    }
}
#include <cctype>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Document {
    int id;
    std::string content;
};

// Builds an inverted index: each lower-cased, whitespace-separated word maps to the ids of the
// documents that contain it. Each id is recorded at most once per word, in document order.
// std::map is an ordered symbol table (usually a red-black tree), so words iterate sorted.
std::map<std::string, std::vector<int>> buildInvertedIndex(const std::vector<Document>& documents) {
    std::map<std::string, std::vector<int>> index;
    for (const Document& doc : documents) {
        // Records a completed word. Documents are processed one at a time, so a repeat of the
        // word inside the same document can only collide with the last recorded id.
        const auto addWord = [&index, &doc](const std::string& word) {
            std::vector<int>& postings = index[word];
            if (postings.empty() || postings.back() != doc.id) {
                postings.push_back(doc.id);
            }
        };

        std::string word;
        for (const char c : doc.content) {
            // <cctype> functions require a value representable as unsigned char;
            // passing a negative char is undefined behavior.
            const auto uc = static_cast<unsigned char>(c);
            if (std::isspace(uc)) {
                if (!word.empty()) {
                    addWord(word);
                    word.clear();
                }
            } else {
                word += static_cast<char>(std::tolower(uc));
            }
        }
        if (!word.empty()) {  // the last word has no trailing whitespace
            addWord(word);
        }
    }
    return index;
}

int main() {
    const std::vector<Document> documents = {
        {1, "The quick brown fox jumps over the lazy dog"},
        {2, "A lazy cat sleeps all day long"},
        {3, "The quick rabbit jumps over the fence"}
    };
    const auto invertedIndex = buildInvertedIndex(documents);

    // Example query: Find documents containing the word "jumps"
    const std::string searchTerm = "jumps";
    if (const auto it = invertedIndex.find(searchTerm); it != invertedIndex.end()) {
        std::cout << "Documents containing '" << searchTerm << "': ";
        for (const int docId : it->second) {
            std::cout << docId << " ";
        }
        std::cout << '\n';
    } else {
        std::cout << "No documents found containing '" << searchTerm << "'" << '\n';
    }
    return 0;
}

13.4 Sparse Vectors

import java.util.HashMap;
import java.util.Map;

public class SparseMatrixVectorMultiplication {
    /** A rows x cols matrix that stores only its non-zero entries in a hash-map symbol table. */
    public static class SparseMatrix {
        private final int rows;
        private final int cols;
        // Key encodes (row, col) as row * cols + col; value is the non-zero entry.
        private final Map<Long, Double> data;

        public SparseMatrix(int rows, int cols) {
            this.rows = rows;
            this.cols = cols;
            this.data = new HashMap<>();
        }

        /**
         * Sets element (row, col). Setting zero removes any stored entry, so the map only ever
         * holds non-zeros (otherwise overwriting a value with 0 would leave the old value behind).
         *
         * @throws IllegalArgumentException if the index is out of range
         */
        public void set(int row, int col, double value) {
            checkIndex(row, col);
            long key = getKey(row, col);
            if (value != 0) {
                data.put(key, value);
            } else {
                data.remove(key);
            }
        }

        /**
         * Returns element (row, col), or 0 if it is not stored.
         *
         * @throws IllegalArgumentException if the index is out of range
         */
        public double get(int row, int col) {
            checkIndex(row, col);
            return data.getOrDefault(getKey(row, col), 0.0);
        }

        // Validates a (row, col) index against the matrix dimensions.
        private void checkIndex(int row, int col) {
            if (row < 0 || row >= rows || col < 0 || col >= cols) {
                throw new IllegalArgumentException("Invalid row or column index");
            }
        }

        // Packs a valid (row, col) into a unique long key; long avoids int overflow for large matrices.
        private long getKey(int row, int col) {
            return (long) row * cols + col;
        }

        /**
         * Computes this * vector, touching only the stored non-zero entries.
         *
         * @throws IllegalArgumentException if vector.length != cols
         */
        public double[] multiply(double[] vector) {
            if (vector.length != cols) {
                throw new IllegalArgumentException("Vector size mismatch");
            }
            double[] result = new double[rows];
            for (Map.Entry<Long, Double> entry : data.entrySet()) {
                long key = entry.getKey();
                int row = (int) (key / cols); // cols > 0 whenever an entry exists
                int col = (int) (key % cols);
                result[row] += entry.getValue() * vector[col];
            }
            return result;
        }
    }
}
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// (row, col) coordinate used as the hash-map key. It doubles as its own hash functor
// (operator()), so it is passed as the Hash template argument of std::unordered_map.
struct RowCol {
    int row;
    int col;

    // Packs both 32-bit indices into one 64-bit value before hashing. Unlike row ^ col, this
    // does not send every diagonal (i, i) to 0 or make (r, c) and (c, r) collide.
    std::size_t operator()(const RowCol& rc) const {
        const std::uint64_t packed =
            (static_cast<std::uint64_t>(static_cast<std::uint32_t>(rc.row)) << 32) |
            static_cast<std::uint32_t>(rc.col);
        return std::hash<std::uint64_t>()(packed);
    }

    // Equality comparison for unordered_map
    bool operator==(const RowCol& other) const { return row == other.row && col == other.col; }
};

// A rows x cols matrix that stores only its non-zero entries in a hash-map symbol table.
class SparseMatrix {
private:
    int rows;
    int cols;
    std::unordered_map<RowCol, double, RowCol> data;  // (row, col) -> non-zero value

    // Throws std::out_of_range unless 0 <= row < rows and 0 <= col < cols.
    void checkIndex(int row, int col) const {
        if (row < 0 || row >= rows || col < 0 || col >= cols) {
            throw std::out_of_range("Invalid row or column index");
        }
    }

public:
    SparseMatrix(int rows, int cols) : rows(rows), cols(cols) {}

    // Sets element (row, col). Zero erases any stored entry so that only non-zeros are kept.
    // Throws std::out_of_range for an invalid index.
    void set(int row, int col, double value) {
        checkIndex(row, col);
        if (value != 0) {
            data[RowCol{row, col}] = value;
        } else {
            data.erase(RowCol{row, col});
        }
    }

    // Returns element (row, col), or 0.0 if it is not stored (single hash lookup).
    // Throws std::out_of_range for an invalid index.
    double get(int row, int col) const {
        checkIndex(row, col);
        const auto it = data.find(RowCol{row, col});
        return it != data.end() ? it->second : 0.0;
    }

    // Computes this * vec, touching only the stored non-zero entries.
    // Throws std::invalid_argument if vec.size() != cols.
    std::vector<double> multiply(const std::vector<double>& vec) const {
        if (vec.size() != static_cast<std::size_t>(cols)) {
            throw std::invalid_argument("Vector size mismatch");
        }
        std::vector<double> result(static_cast<std::size_t>(rows), 0.0);
        for (const auto& [key, value] : data) {
            result[static_cast<std::size_t>(key.row)] += value * vec[static_cast<std::size_t>(key.col)];
        }
        return result;
    }
};
class SparseMatrix:
    """A rows x cols matrix storing only its non-zero entries in a dict keyed by (row, col)."""

    def __init__(self, rows, cols):
        self.rows = rows
        self.cols = cols
        self.data = {}  # Using a dictionary as a symbol table: (row, col) -> non-zero value

    def _check_index(self, row, col):
        """Raises ValueError unless 0 <= row < rows and 0 <= col < cols."""
        if row < 0 or row >= self.rows or col < 0 or col >= self.cols:
            raise ValueError("Invalid row or column index")

    def set(self, row, col, value):
        """Sets element (row, col); setting 0 removes any stored entry.

        Raises:
            ValueError: if the index is out of range.
        """
        self._check_index(row, col)
        if value != 0:
            self.data[(row, col)] = value
        else:
            # Overwriting with zero must drop a previously stored non-zero entry.
            self.data.pop((row, col), None)

    def get(self, row, col):
        """Returns element (row, col), or 0 if it is not stored.

        Raises:
            ValueError: if the index is out of range.
        """
        self._check_index(row, col)
        return self.data.get((row, col), 0)  # Return 0 if not found

    def multiply(self, vector):
        """Returns this * vector as a list, touching only the stored non-zero entries.

        Raises:
            ValueError: if len(vector) != cols.
        """
        if len(vector) != self.cols:
            raise ValueError("Vector size mismatch")
        result = [0] * self.rows
        for (row, col), value in self.data.items():
            result[row] += value * vector[col]
        return result

14 Undirected Graphs

14.1 Introduction to Graphs

Terminology

  1. Graph: Set of vertices connected pairwise by edges.

  2. Path: Sequence of vertices connected by edges.

  3. Cycle: Path whose first and last vertices are the same.

  4. Two vertices are connected if there is a path between them.

undirected graph

14.2 Graph API

Representation Types

  1. Set-of-edges graph representation: Maintain a list of the edges (linked list or array).

  2. Adjacency-matrix graph representation: Maintain a two-dimensional $V$-by-$V$ boolean array; for each edge $v$–$w$ in the graph: adj[v][w] = adj[w][v] = true.

  3. Adjacency-list graph representation: Maintain vertex-indexed array of lists.

In practice: use adjacency-lists representation.

  • Algorithms based on iterating over vertices adjacent to $v$.

  • Real-world graphs tend to be sparse (huge number of vertices, small average vertex degree).

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/** Undirected graph on vertices 0..numVertices-1, stored as one adjacency list per vertex. */
public class UndirectedGraph {
    private final int numVertices;
    private final List<List<Integer>> adjacencyList;

    /** Creates a graph with {@code numVertices} vertices and no edges. */
    public UndirectedGraph(int numVertices) {
        this.numVertices = numVertices;
        this.adjacencyList = new ArrayList<>(numVertices);
        int v = 0;
        while (v < numVertices) {
            this.adjacencyList.add(new LinkedList<>());
            v++;
        }
    }

    /** Adds the edge source–destination, recording each endpoint in the other's list. */
    public void addEdge(int source, int destination) {
        List<Integer> fromSource = adjacencyList.get(source);
        fromSource.add(destination);
        List<Integer> fromDestination = adjacencyList.get(destination);
        fromDestination.add(source);
    }

    public int getNumVertices() {
        return numVertices;
    }

    /** Returns the live adjacency lists (index v holds v's neighbors). */
    public List<List<Integer>> getAdjacencyList() {
        return adjacencyList;
    }

    /** Prints one line per vertex: "Vertex v: -> n1 -> n2 ...". */
    public void printGraph() {
        for (int v = 0; v < numVertices; v++) {
            StringBuilder line = new StringBuilder("Vertex ").append(v).append(':');
            for (Integer neighbor : adjacencyList.get(v)) {
                line.append(" -> ").append(neighbor);
            }
            System.out.println(line);
        }
    }
}
#ifndef UNDIRECTEDGRAPH_H
#define UNDIRECTEDGRAPH_H
#pragma once

#include <vector>
#include <list>

/// Undirected graph over vertices 0..numVertices-1, stored as one adjacency list per vertex.
class UndirectedGraph {
public:
    /// Creates a graph with @p numVertices vertices and no edges.
    explicit UndirectedGraph(const int& numVertices);

    /// Adds the edge source–destination; each endpoint is appended to the other's list
    /// (parallel edges are not filtered out).
    void addEdge(const int& source, const int& destination);

    /// @return true if @p destination appears in @p source's adjacency list.
    [[nodiscard]] bool hasEdge(const int& source, const int& destination) const;

    /// @return the number of vertices.
    [[nodiscard]] int getNumVertices() const;

    /// @return read-only view of all adjacency lists (index v holds v's neighbors).
    [[nodiscard]] const std::vector<std::list<int>>& getAdjacencyList() const;

    /// Prints one line per vertex to std::cout.
    void printGraph() const;

private:
    int numVertices;
    std::vector<std::list<int>> adjacencyList;
};

#endif  // UNDIRECTEDGRAPH_H
#include "UndirectedGraph.h"

#include <iostream>
#include <algorithm>

// One empty neighbor list per vertex.
UndirectedGraph::UndirectedGraph(const int& numVertices)
    : numVertices(numVertices), adjacencyList(numVertices) {}

// An undirected edge is stored twice: once in each endpoint's list.
void UndirectedGraph::addEdge(const int& source, const int& destination) {
    auto& sourceNeighbors = adjacencyList[source];
    sourceNeighbors.push_back(destination);
    auto& destinationNeighbors = adjacencyList[destination];
    destinationNeighbors.push_back(source);
}

// Linear scan of source's neighbor list.
bool UndirectedGraph::hasEdge(const int& source, const int& destination) const {
    const auto& neighbors = adjacencyList[source];
    return std::ranges::find(neighbors, destination) != neighbors.end();
}

int UndirectedGraph::getNumVertices() const {
    return numVertices;
}

const std::vector<std::list<int>>& UndirectedGraph::getAdjacencyList() const {
    return adjacencyList;
}

// Prints "Vertex v: -> n1 -> n2 ..." for each vertex.
void UndirectedGraph::printGraph() const {
    for (int vertex = 0; vertex < numVertices; ++vertex) {
        std::cout << "Vertex " << vertex << ":";
        const auto& neighbors = adjacencyList[vertex];
        for (auto it = neighbors.begin(); it != neighbors.end(); ++it) {
            std::cout << " -> " << *it;
        }
        std::cout << std::endl;
    }
}
class UndirectedGraph:
    """Undirected graph on vertices 0..num_vertices-1, stored as adjacency lists."""

    def __init__(self, num_vertices):
        self.num_vertices = num_vertices
        self.adjacency_list = []
        for _ in range(num_vertices):
            self.adjacency_list.append([])

    def add_edge(self, source, destination):
        """Adds the undirected edge source-destination (recorded in both endpoints' lists)."""
        for owner, neighbor in ((source, destination), (destination, source)):
            self.adjacency_list[owner].append(neighbor)

    def get_num_vertices(self):
        """Returns the number of vertices."""
        return self.num_vertices

    def get_adjacency_list(self):
        """Returns the adjacency lists; index v holds v's neighbors."""
        return self.adjacency_list

    def print_graph(self):
        """Prints one line per vertex: 'Vertex v: -> n1 -> n2 ...'."""
        for i in range(self.num_vertices):
            arrows = "".join(f" -> {vertex}" for vertex in self.adjacency_list[i])
            print(f"Vertex {i}:" + arrows)

Goal: Systematically search through a graph.

Typical applications

  • Find all vertices connected to a given source vertex.

  • Find a path between two vertices.

Depth-First Search

  1. Mark vertex $v$ as visited.

  2. Recursively visit all the unmarked vertices adjacent to $v$.

Properties

  • DFS marks all vertices connected to the source $s$ in time proportional to the sum of their degrees.

  • After DFS, one can check whether a vertex $v$ is connected to $s$ in constant time, and find a path from $s$ to $v$ (if one exists) in time proportional to its length.

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Single-source reachability and paths in an undirected graph, found with an explicit stack
 * instead of recursion.
 *
 * <p>Vertices are marked when pushed, so the visit order (and the edgeTo tree) is that of a
 * stack-based search rather than recursive DFS; every recorded path is still a valid path.
 */
public class DepthFirstSearch {
    private final boolean[] marked; // marked[v]: v is reachable from source
    private final int[] edgeTo;     // edgeTo[v]: previous vertex on the recorded path source -> v
    private final int source;       // start vertex; needed to know where a path walk ends

    public DepthFirstSearch(UndirectedGraph graph, int source) {
        this.marked = new boolean[graph.getNumVertices()];
        this.edgeTo = new int[graph.getNumVertices()];
        this.source = source;
        dfs(graph, source);
    }

    // Marks every vertex reachable from source, printing each vertex as it is popped.
    private void dfs(UndirectedGraph graph, int source) {
        Deque<Integer> stack = new ArrayDeque<>();
        marked[source] = true;
        stack.push(source);
        while (!stack.isEmpty()) {
            int v = stack.pop();
            System.out.print(v + " ");
            for (int w : graph.getAdjacencyList().get(v)) {
                if (!marked[w]) {
                    marked[w] = true;
                    edgeTo[w] = v;
                    stack.push(w);
                }
            }
        }
    }

    /** Returns true if {@code v} is reachable from the source. */
    public boolean hasPathTo(int v) {
        return marked[v];
    }

    /** Prints "Path: source -> ... -> v", or a message if {@code v} is unreachable. */
    public void printPathTo(int v) {
        if (!marked[v]) {
            System.out.println("No path from source to " + v);
            return;
        }
        // Walk parent links back to the actual source (not a hard-coded vertex 0).
        Deque<Integer> path = new ArrayDeque<>();
        for (int x = v; x != source; x = edgeTo[x]) {
            path.push(x);
        }
        path.push(source);
        System.out.print("Path: ");
        while (!path.isEmpty()) {
            System.out.print(path.pop());
            if (!path.isEmpty()) {
                System.out.print(" -> ");
            }
        }
        System.out.println();
    }
}
#ifndef DEPTHFIRSTSEARCH_H #define DEPTHFIRSTSEARCH_H #pragma once #include <vector> #include "UndirectedGraph.h" class DepthFirstSearch { private: const UndirectedGraph& graph; std::vector<bool> marked; std::vector<int> edgeTo; public: DepthFirstSearch(const UndirectedGraph& graph, int source); void dfs(int v); [[nodiscard]] bool hasPathTo(int v) const; void printPathTo(int v) const; }; #endif //DEPTHFIRSTSEARCH_H
#include "DepthFirstSearch.h"

#include <iostream>
#include <stack>

DepthFirstSearch::DepthFirstSearch(const UndirectedGraph& graph, const int source)
    : graph(graph), marked(graph.getNumVertices(), false), edgeTo(graph.getNumVertices(), -1) {
    dfs(source);
}

// Explicit-stack search from v. A vertex is marked when it is pushed, so it enters the stack
// at most once. The resulting order/tree is that of a stack-based search rather than
// recursive DFS, but every edgeTo chain is still a valid path back to the root v.
void DepthFirstSearch::dfs(const int v) {
    std::stack<int> stack;
    marked[v] = true;
    stack.push(v);
    while (!stack.empty()) {
        const int current = stack.top();
        stack.pop();
        std::cout << current << " ";
        for (int neighbor : this->graph.getAdjacencyList()[current]) {
            if (!marked[neighbor]) {
                marked[neighbor] = true;
                edgeTo[neighbor] = current;
                stack.push(neighbor);
            }
        }
    }
}

bool DepthFirstSearch::hasPathTo(const int v) const {
    return marked[v];
}

// Prints "Path: root -> ... -> v". The search root is a reached vertex whose edgeTo entry is
// still -1, so the parent chain is followed until -1; this includes the root itself and
// every intermediate vertex.
void DepthFirstSearch::printPathTo(const int v) const {
    if (!hasPathTo(v)) {
        std::cout << "No path from source to " << v << std::endl;
        return;
    }
    std::stack<int> path;
    for (int x = v; x != -1; x = edgeTo[x]) {
        path.push(x);
    }
    std::cout << "Path: ";
    while (!path.empty()) {
        std::cout << path.top();
        path.pop();
        if (!path.empty()) {
            std::cout << " -> ";
        }
    }
    std::cout << std::endl;
}
from UndirectedGraph import UndirectedGraph


class DepthFirstSearch:
    """Single-source reachability and paths in an undirected graph, using an explicit stack.

    Vertices are marked when pushed, so the visit order (and the edge_to tree) is that of a
    stack-based search rather than recursive DFS; every recorded path is nevertheless valid.
    """

    def __init__(self, graph: UndirectedGraph, source: int):
        vertex_count = graph.get_num_vertices()
        self.marked = [False] * vertex_count  # marked[v]: v reachable from source
        self.edge_to = [None] * vertex_count  # edge_to[v]: predecessor of v; None for source/unreached
        self.dfs(graph, source)

    def dfs(self, graph, source):
        """Marks every vertex reachable from source, printing each one as it is popped."""
        self.marked[source] = True
        to_visit = [source]
        while to_visit:
            current = to_visit.pop()
            print(current, end=" ")
            for neighbor in graph.get_adjacency_list()[current]:
                if self.marked[neighbor]:
                    continue
                self.marked[neighbor] = True
                self.edge_to[neighbor] = current
                to_visit.append(neighbor)

    def has_path_to(self, v):
        """Returns True if v is reachable from the source."""
        return self.marked[v]

    def print_path_to(self, v):
        """Prints 'Path: source -> ... -> v', or a message if v is unreachable."""
        if not self.marked[v]:
            print(f"No path from source to {v}")
            return
        trail = []
        vertex = v
        while vertex is not None:
            trail.append(vertex)
            vertex = self.edge_to[vertex]
        print("Path:", " -> ".join(str(step) for step in reversed(trail)))

Breadth-First Search

  1. Put s onto a FIFO queue, and mark s as visited.

  2. Repeat until the queue is empty.

  3. Remove the least recently added vertex v.

  4. Add each of v's unvisited neighbors to the queue, and mark them as visited.

Property

BFS computes shortest paths (fewest number of edges) from $s$ to all other vertices in a graph in time proportional to $E + V$.

  • Depth-first search: put unvisited vertices on stack.

  • Breadth-first search: put unvisited vertices on queue.

import java.util.ArrayDeque; import java.util.List; import java.util.Queue; public class BreadthFirstSearch { private boolean[] marked; private int[] edgeTo; private int[] distanceTo; public void bfs(UndirectedGraph graph, int startVertex) { marked = new boolean[graph.getNumVertices()]; edgeTo = new int[graph.getNumVertices()]; distanceTo = new int[graph.getNumVertices()]; Queue<Integer> queue = new ArrayDeque<>(); marked[startVertex] = true; distanceTo[startVertex] = 0; queue.offer(startVertex); while (!queue.isEmpty()) { int currentVertex = queue.poll(); List<List<Integer>> adjList = graph.getAdjacencyList(); for (int adjacentVertex : adjList.get(currentVertex)) { if (!marked[adjacentVertex]) { marked[adjacentVertex] = true; edgeTo[adjacentVertex] = currentVertex; distanceTo[adjacentVertex] = distanceTo[currentVertex] + 1; // Update distance queue.offer(adjacentVertex); } } } } public int getDistance(int destination) { if (!marked[destination]) { return -1; } return distanceTo[destination]; } public void printPath(int start, int end) { if (start == end) { System.out.print(start); return; } if (edgeTo[end] == 0) { System.out.print("No path exists"); return; } printPath(start, edgeTo[end]); System.out.print(" -> " + end); } }
#ifndef BREADTHFIRSTSEARCH_H #define BREADTHFIRSTSEARCH_H #pragma once #include "UndirectedGraph.h" #include <vector> class BreadthFirstSearch { private: const UndirectedGraph& graph; int startVertex; std::vector<bool> marked; std::vector<int> edgeTo; std::vector<int> distanceTo; public: BreadthFirstSearch(const UndirectedGraph& graph, int startVertex); void bfs(); [[nodiscard]] int getDistance(int destination) const; void printPath(int destination) const; }; #endif //BREADTHFIRSTSEARCH_H
#include "BreadthFirstSearch.h"

#include <iostream>
#include <queue>

// Sizes the per-vertex state only; the search itself runs in bfs().
BreadthFirstSearch::BreadthFirstSearch(const UndirectedGraph& graph, const int startVertex)
    : graph(graph),
      startVertex(startVertex),
      marked(graph.getNumVertices(), false),
      edgeTo(graph.getNumVertices(), -1),
      distanceTo(graph.getNumVertices(), 0) {}

// FIFO order dequeues vertices in nondecreasing distance, so the first time a vertex is
// discovered is along a shortest path.
void BreadthFirstSearch::bfs() {
    std::queue<int> frontier;
    frontier.push(startVertex);
    marked[startVertex] = true;
    while (!frontier.empty()) {
        const int vertex = frontier.front();
        frontier.pop();
        for (const int& next : graph.getAdjacencyList()[vertex]) {
            if (marked[next]) {
                continue;
            }
            marked[next] = true;
            edgeTo[next] = vertex;
            distanceTo[next] = distanceTo[vertex] + 1;
            frontier.push(next);
        }
    }
}

int BreadthFirstSearch::getDistance(const int destination) const {
    return marked[destination] ? distanceTo[destination] : -1;
}

// Recursively prints the prefix of the path first, then the final hop.
void BreadthFirstSearch::printPath(const int destination) const {
    if (destination == startVertex) {
        std::cout << startVertex;
    } else if (edgeTo[destination] == -1) {
        std::cout << "No path exists";
    } else {
        printPath(edgeTo[destination]);
        std::cout << " -> " << destination;
    }
}
from collections import deque
from UndirectedGraph import UndirectedGraph
import sys


class BreadthFirstSearch:
    """Shortest paths (fewest edges) from a start vertex, computed by breadth-first search."""

    def __init__(self, graph: UndirectedGraph, start_vertex: int):
        n = graph.get_num_vertices()
        self.marked = [False] * n
        self.edge_to = [None] * n
        self.distance_to = [sys.maxsize] * n  # sys.maxsize stands in for "unreached"
        self.bfs(graph, start_vertex)

    def bfs(self, graph: UndirectedGraph, start_vertex: int):
        """Visits every vertex reachable from start_vertex in FIFO order."""
        frontier = deque((start_vertex,))
        self.marked[start_vertex] = True
        self.distance_to[start_vertex] = 0
        while frontier:
            vertex = frontier.popleft()
            for nxt in graph.get_adjacency_list()[vertex]:
                if self.marked[nxt]:
                    continue
                self.marked[nxt] = True
                self.edge_to[nxt] = vertex
                self.distance_to[nxt] = self.distance_to[vertex] + 1
                frontier.append(nxt)

    def get_distance(self, destination: int) -> int:
        """Returns the shortest-path length in edges, or -1 if destination is unreached."""
        return self.distance_to[destination] if self.marked[destination] else -1

    def print_path(self, start: int, end: int):
        """Prints 'start -> ... -> end' (no newline), or 'No path exists'."""
        if start == end:
            print(start, end="")
        elif self.edge_to[end] is None:
            print("No path exists")
        else:
            self.print_path(start, self.edge_to[end])
            print(f" -> {end}", end="")

14.5 Connected Components

Connected Components: A connected component is maximal set of connected vertices.

Find all Connected Components

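  For each vertex that is still unmarked, start a new component and run DFS from it:

  1. Mark vertex $v$ as visited and assign it the current component id.

  2. Recursively visit all the unmarked vertices adjacent to $v$.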

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Labels every vertex of an undirected graph with the id (0..count-1) of its connected
 * component, using DFS from each not-yet-labeled vertex.
 */
public class ConnectedComponents {
    // "Not yet visited". It can never equal a component id (ids are >= 0), unlike using the
    // vertex's own index as the marker, which a visited vertex's id could coincide with.
    private static final int UNLABELED = -1;

    private final int[] id; // id[v]: component id of vertex v
    private int count;      // number of components found

    public ConnectedComponents(UndirectedGraph graph) {
        int numVertices = graph.getNumVertices();
        id = new int[numVertices];
        Arrays.fill(id, UNLABELED);
        count = 0;
        for (int v = 0; v < numVertices; v++) {
            if (id[v] == UNLABELED) {
                dfs(graph, v);
                count++;
            }
        }
    }

    // Assigns the current component id to v and every unlabeled vertex reachable from it.
    private void dfs(UndirectedGraph graph, int v) {
        id[v] = count;
        for (int w : graph.getAdjacencyList().get(v)) {
            if (id[w] == UNLABELED) {
                dfs(graph, w);
            }
        }
    }

    /** Returns true if v and w lie in the same connected component. */
    public boolean isConnected(int v, int w) {
        return id[v] == id[w];
    }

    public int getCount() {
        return count;
    }

    /** Prints the component count followed by the vertices of each component. */
    public void printComponents() {
        System.out.println("Number of connected components: " + count);
        List<List<Integer>> components = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            components.add(new ArrayList<>());
        }
        for (int i = 0; i < id.length; i++) {
            components.get(id[i]).add(i);
        }
        for (int i = 0; i < count; i++) {
            System.out.println("Component " + i + ": " + components.get(i));
        }
    }
}
#ifndef CONNECTEDCOMPONENTS_H #define CONNECTEDCOMPONENTS_H #include <vector>> #include "UndirectedGraph.h" class ConnectedComponents { private: std::vector<int> id; int count; void dfs(const UndirectedGraph& graph, int v); public: explicit ConnectedComponents(const UndirectedGraph& graph); [[nodiscard]] bool isConnected(int v, int w) const; [[nodiscard]] int getCount() const; void printComponents() const; }; #endif // CONNECTEDCOMPONENTS_H
#include "ConnectedComponents.h"

#include <cstddef>
#include <iostream>

namespace {
// "Not yet visited". It can never equal a component id (ids are >= 0), unlike using a vertex's
// own index as the marker, which a visited vertex's component id could coincide with.
constexpr int kUnlabeled = -1;
}  // namespace

ConnectedComponents::ConnectedComponents(const UndirectedGraph& graph) : count(0) {
    const int numVertices = graph.getNumVertices();
    id.assign(numVertices, kUnlabeled);
    for (int i = 0; i < numVertices; ++i) {
        if (id[i] == kUnlabeled) {
            dfs(graph, i);
            ++count;
        }
    }
}

// Recursive DFS that labels v and every unlabeled vertex reachable from it with `count`.
void ConnectedComponents::dfs(const UndirectedGraph& graph, int v) {
    id[v] = count;
    for (const int& w : graph.getAdjacencyList()[v]) {
        if (id[w] == kUnlabeled) {
            dfs(graph, w);
        }
    }
}

bool ConnectedComponents::isConnected(int v, int w) const {
    return id[v] == id[w];
}

int ConnectedComponents::getCount() const {
    return count;
}

// Groups vertices by component id, then prints each group.
void ConnectedComponents::printComponents() const {
    std::cout << "Number of connected components: " << count << std::endl;
    std::vector<std::vector<int>> components(count);
    for (std::size_t i = 0; i < id.size(); ++i) {
        components[id[i]].push_back(static_cast<int>(i));
    }
    for (int i = 0; i < count; ++i) {
        std::cout << "Component " << i << ": ";
        for (const int& vertex : components[i]) {
            std::cout << vertex << " ";
        }
        std::cout << std::endl;
    }
}
from UndirectedGraph import UndirectedGraph

# Marker for "not yet visited". It can never equal a component id (ids are >= 0), unlike
# using a vertex's own index as the marker.
_UNLABELED = -1


class ConnectedComponents:
    """Labels each vertex with the id (0..count-1) of its connected component."""

    def __init__(self, graph: UndirectedGraph):
        self.id = [_UNLABELED] * graph.get_num_vertices()  # id[v]: component id of v
        self.count = 0
        for i in range(graph.get_num_vertices()):
            if self.id[i] == _UNLABELED:
                self.dfs(graph, i)
                self.count += 1

    def dfs(self, graph: UndirectedGraph, v: int):
        """Labels v and every unlabeled vertex reachable from it with the current count.

        Uses an explicit stack: a recursive version would exceed Python's default recursion
        limit (about 1000 frames) on long paths.
        """
        adjacency = graph.get_adjacency_list()
        self.id[v] = self.count
        stack = [v]
        while stack:
            x = stack.pop()
            for w in adjacency[x]:
                if self.id[w] == _UNLABELED:
                    self.id[w] = self.count
                    stack.append(w)

    def is_connected(self, v: int, w: int) -> bool:
        """Returns True if v and w lie in the same component."""
        return self.id[v] == self.id[w]

    def get_count(self) -> int:
        """Returns the number of connected components."""
        return self.count

    def print_components(self):
        """Prints the component count followed by each component's vertex list."""
        print("Number of connected components:", self.count)
        components = [[] for _ in range(self.count)]
        for i in range(len(self.id)):
            components[self.id[i]].append(i)
        for i, component in enumerate(components):
            print(f"Component {i}: {component}")

14.6 Important Questions

  1. Q: Implement depth-first search in an undirected graph without using recursion.

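    A: Use an explicit stack in place of recursion. Note that simply replacing the queue with a stack in breadth-first search (marking vertices when they are pushed) still finds exactly the reachable vertices and valid paths, but does not visit them in true depth-first order; for genuine DFS order, mark a vertex when it is popped and skip it if it is already marked.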

  2. Given a connected graph with no cycles

    • Q: Diameter: design a linear-time algorithm to find the longest simple path in the graph.

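      A: To compute the diameter, pick any vertex $s$ and run BFS from $s$; let $v$ be the vertex farthest from $s$; then run BFS again from $v$. The largest distance found by the second BFS is the diameter, and the path to that farthest vertex is a longest path.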

    • Q: Center: design a linear-time algorithm to find the center of the graph.

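      A: Consider the vertices on a longest path (found as above): the center is the middle vertex (or one of the two middle vertices) of that path.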

  3. Q: An Euler cycle in a graph is a cycle (not necessarily simple) that uses every edge in the graph exactly once. Design a linear-time algorithm to determine whether a graph has an Euler cycle, and if so, find one.

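    A: A graph has an Euler cycle if and only if every vertex has even degree and all edges lie in a single connected component. If so, use depth-first search and piece together the cycles you discover (Hierholzer's algorithm).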

15 Directed Graphs

15.1 Introduction to Directed Graphs

Directed graph: Set of vertices connected pairwise by directed edges.

Directed graph

15.2 Directed Graph API

import java.util.ArrayList;
import java.util.List;

/** Directed graph on vertices 0..numVertices-1, stored as lists of out-neighbors. */
public class DirectedGraph {
    private final int numVertices;
    private int numEdges;
    private final List<List<Integer>> adjacencyList;

    /** Creates a digraph with {@code numVertices} vertices and no edges. */
    public DirectedGraph(int numVertices) {
        this.numVertices = numVertices;
        this.numEdges = 0;
        this.adjacencyList = new ArrayList<>();
        int v = 0;
        while (v < numVertices) {
            this.adjacencyList.add(new ArrayList<>());
            v++;
        }
    }

    /** Adds the directed edge source -> destination and bumps the edge count. */
    public void addEdge(int source, int destination) {
        List<Integer> outNeighbors = adjacencyList.get(source);
        outNeighbors.add(destination);
        numEdges++;
    }

    public int getNumVertices() {
        return numVertices;
    }

    public int getNumEdges() {
        return numEdges;
    }

    /** Returns the live adjacency lists (index v holds v's out-neighbors). */
    public List<List<Integer>> getAdjacencyList() {
        return adjacencyList;
    }

    /** Prints one line per vertex listing its out-neighbors. */
    public void printGraph() {
        for (int v = 0; v < numVertices; v++) {
            StringBuilder line = new StringBuilder("Adjacency list of vertex ").append(v).append(" : ");
            for (Integer neighbor : adjacencyList.get(v)) {
                line.append(neighbor).append(' ');
            }
            System.out.println(line);
        }
    }
}
#ifndef DIRECTEDGRAPH_H
#define DIRECTEDGRAPH_H
#pragma once

#include <vector>
#include <list>

/// Directed graph over vertices 0..numVertices-1, stored as one out-neighbor list per vertex.
class DirectedGraph {
public:
    /// Creates a digraph with @p numVertices vertices and no edges.
    explicit DirectedGraph(const int& numVertices);

    /// Adds the directed edge source -> destination (parallel edges are not filtered out).
    void addEdge(const int& source, const int& destination);

    /// @return true if @p destination appears in @p source's out-neighbor list.
    [[nodiscard]] bool hasEdge(const int& source, const int& destination) const;

    /// @return the number of vertices.
    [[nodiscard]] int getNumVertices() const;

    /// @return read-only view of all adjacency lists (index v holds v's out-neighbors).
    [[nodiscard]] const std::vector<std::list<int>>& getAdjacencyList() const;

    /// Prints one line per vertex to std::cout.
    void printGraph() const;

private:
    int numVertices;
    std::vector<std::list<int>> adjacencyList;
};

#endif  // DIRECTEDGRAPH_H
#include "DirectedGraph.h"

#include <iostream>
#include <algorithm>

// One empty out-neighbor list per vertex.
DirectedGraph::DirectedGraph(const int& numVertices)
    : numVertices(numVertices), adjacencyList(numVertices) {}

// A directed edge is stored only in the source's list.
void DirectedGraph::addEdge(const int& source, const int& destination) {
    auto& outNeighbors = adjacencyList[source];
    outNeighbors.push_back(destination);
}

// Linear scan of source's out-neighbor list.
bool DirectedGraph::hasEdge(const int& source, const int& destination) const {
    const auto& outNeighbors = adjacencyList[source];
    return std::ranges::find(outNeighbors, destination) != outNeighbors.end();
}

int DirectedGraph::getNumVertices() const {
    return numVertices;
}

const std::vector<std::list<int>>& DirectedGraph::getAdjacencyList() const {
    return adjacencyList;
}

// Prints "Vertex v: -> n1 -> n2 ..." for each vertex.
void DirectedGraph::printGraph() const {
    for (int vertex = 0; vertex < numVertices; ++vertex) {
        std::cout << "Vertex " << vertex << ":";
        const auto& outNeighbors = adjacencyList[vertex];
        for (auto it = outNeighbors.begin(); it != outNeighbors.end(); ++it) {
            std::cout << " -> " << *it;
        }
        std::cout << std::endl;
    }
}
class DirectedGraph:
    """Directed graph on vertices 0..num_vertices-1, stored as lists of out-neighbors."""

    def __init__(self, num_vertices):
        self.num_vertices = num_vertices
        self.num_edges = 0
        self.adjacency_list = [[] for _ in range(num_vertices)]

    def add_edge(self, source, destination):
        """Adds the directed edge source -> destination."""
        self.adjacency_list[source].append(destination)
        self.num_edges += 1

    def get_num_vertices(self):
        """Returns the number of vertices."""
        return self.num_vertices

    def get_num_edges(self):
        """Returns the number of edges added so far."""
        return self.num_edges

    def get_adjacency_list(self):
        """Returns the adjacency lists; index v holds v's out-neighbors.

        Mirrors UndirectedGraph.get_adjacency_list so graph-search code can accept either class.
        """
        return self.adjacency_list

    def print_graph(self):
        """Prints one line per vertex listing its out-neighbors."""
        for v in range(self.num_vertices):
            print(f"Adjacency list of vertex {v} : ", end="")
            for neighbor in self.adjacency_list[v]:
                print(f"{neighbor} ", end="")
            print()

15.3.1 Depth-First Search for Digraph

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Single-source reachability and paths in a directed graph (edges followed only in their
 * stated direction), found with an explicit stack instead of recursion.
 *
 * <p>Vertices are marked when pushed, so the visit order (and the edgeTo tree) is that of a
 * stack-based search rather than recursive DFS; every recorded path is still a valid path.
 */
public class DirectedDepthFirstSearch {
    private final boolean[] marked; // marked[v]: v is reachable from source
    private final int[] edgeTo;     // edgeTo[v]: previous vertex on the recorded path source -> v
    private final int source;       // start vertex; needed to know where a path walk ends

    public DirectedDepthFirstSearch(DirectedGraph graph, int source) {
        this.marked = new boolean[graph.getNumVertices()];
        this.edgeTo = new int[graph.getNumVertices()];
        this.source = source;
        dfs(graph, source);
    }

    // Marks every vertex reachable from source, printing each vertex as it is popped.
    private void dfs(DirectedGraph graph, int source) {
        Deque<Integer> stack = new ArrayDeque<>();
        marked[source] = true;
        stack.push(source);
        while (!stack.isEmpty()) {
            int v = stack.pop();
            System.out.print(v + " ");
            for (int w : graph.getAdjacencyList().get(v)) {
                if (!marked[w]) {
                    marked[w] = true;
                    edgeTo[w] = v;
                    stack.push(w);
                }
            }
        }
    }

    /** Returns true if {@code v} is reachable from the source. */
    public boolean hasPathTo(int v) {
        return marked[v];
    }

    /** Prints "Path: source -> ... -> v", or a message if {@code v} is unreachable. */
    public void printPathTo(int v) {
        if (!marked[v]) {
            System.out.println("No path from source to " + v);
            return;
        }
        // Walk parent links back to the actual source (not a hard-coded vertex 0).
        Deque<Integer> path = new ArrayDeque<>();
        for (int x = v; x != source; x = edgeTo[x]) {
            path.push(x);
        }
        path.push(source);
        System.out.print("Path: ");
        while (!path.isEmpty()) {
            System.out.print(path.pop());
            if (!path.isEmpty()) {
                System.out.print(" -> ");
            }
        }
        System.out.println();
    }
}
#ifndef DIRECTEDDEPTHFIRSTSEARCH_H #define DIRECTEDDEPTHFIRSTSEARCH_H #pragma once #include <vector> #include "DirectedGraph.h" class DirectedDepthFirstSearch { private: const DirectedGraph& graph; std::vector<bool> marked; std::vector<int> edgeTo; public: DirectedDepthFirstSearch(const DirectedGraph& graph, int source); void dfs(int v); [[nodiscard]] bool hasPathTo(int v) const; void printPathTo(int v) const; }; #endif //DIRECTEDDEPTHFIRSTSEARCH_H
#include "DirectedDepthFirstSearch.h"

#include <iostream>
#include <stack>

DirectedDepthFirstSearch::DirectedDepthFirstSearch(const DirectedGraph& graph, const int source)
    : graph(graph), marked(graph.getNumVertices(), false), edgeTo(graph.getNumVertices(), -1) {
    dfs(source);
}

// Explicit-stack search from v along directed edges. A vertex is marked when it is pushed, so
// it enters the stack at most once. The order/tree is that of a stack-based search rather
// than recursive DFS, but every edgeTo chain is still a valid path back to the root v.
void DirectedDepthFirstSearch::dfs(const int v) {
    std::stack<int> stack;
    marked[v] = true;
    stack.push(v);
    while (!stack.empty()) {
        const int current = stack.top();
        stack.pop();
        std::cout << current << " ";
        for (int neighbor : this->graph.getAdjacencyList()[current]) {
            if (!marked[neighbor]) {
                marked[neighbor] = true;
                edgeTo[neighbor] = current;
                stack.push(neighbor);
            }
        }
    }
}

bool DirectedDepthFirstSearch::hasPathTo(const int v) const {
    return marked[v];
}

// Prints "Path: root -> ... -> v". The search root is a reached vertex whose edgeTo entry is
// still -1, so the parent chain is followed until -1; this includes the root itself and
// every intermediate vertex.
void DirectedDepthFirstSearch::printPathTo(const int v) const {
    if (!hasPathTo(v)) {
        std::cout << "No path from source to " << v << std::endl;
        return;
    }
    std::stack<int> path;
    for (int x = v; x != -1; x = edgeTo[x]) {
        path.push(x);
    }
    std::cout << "Path: ";
    while (!path.empty()) {
        std::cout << path.top();
        path.pop();
        if (!path.empty()) {
            std::cout << " -> ";
        }
    }
    std::cout << std::endl;
}
from DirectedGraph import DirectedGraph


class DirectedDepthFirstSearch:
    """Single-source reachability and paths in a directed graph, using an explicit stack.

    Vertices are marked when pushed, so the visit order (and the edge_to tree) is that of a
    stack-based search rather than recursive DFS; every recorded path is nevertheless valid.
    The graph object must provide get_num_vertices() and get_adjacency_list().
    """

    def __init__(self, graph: DirectedGraph, source: int):
        vertex_count = graph.get_num_vertices()
        self.marked = [False] * vertex_count  # marked[v]: v reachable from source
        self.edge_to = [None] * vertex_count  # edge_to[v]: predecessor of v; None for source/unreached
        self.dfs(graph, source)

    def dfs(self, graph, source):
        """Marks every vertex reachable from source, printing each one as it is popped."""
        self.marked[source] = True
        to_visit = [source]
        while to_visit:
            current = to_visit.pop()
            print(current, end=" ")
            for neighbor in graph.get_adjacency_list()[current]:
                if self.marked[neighbor]:
                    continue
                self.marked[neighbor] = True
                self.edge_to[neighbor] = current
                to_visit.append(neighbor)

    def has_path_to(self, v):
        """Returns True if v is reachable from the source."""
        return self.marked[v]

    def print_path_to(self, v):
        """Prints 'Path: source -> ... -> v', or a message if v is unreachable."""
        if not self.marked[v]:
            print(f"No path from source to {v}")
            return
        trail = []
        vertex = v
        while vertex is not None:
            trail.append(vertex)
            vertex = self.edge_to[vertex]
        print("Path:", " -> ".join(str(step) for step in reversed(trail)))

15.3.2 Breadth-First Search for Digraph

Reachability applications:

  • Program control-flow analysis

  • Mark-sweep garbage collector: if an object is unreachable, it is garbage.

Application

  • Web crawler

import java.util.ArrayDeque; import java.util.List; import java.util.Queue; public class BreadthFirstSearch { private boolean[] marked; private int[] edgeTo; private int[] distanceTo; public void bfs(UndirectedGraph graph, int startVertex) { marked = new boolean[graph.getNumVertices()]; edgeTo = new int[graph.getNumVertices()]; distanceTo = new int[graph.getNumVertices()]; Queue<Integer> queue = new ArrayDeque<>(); marked[startVertex] = true; distanceTo[startVertex] = 0; queue.offer(startVertex); while (!queue.isEmpty()) { int currentVertex = queue.poll(); List<List<Integer>> adjList = graph.getAdjacencyList(); for (int adjacentVertex : adjList.get(currentVertex)) { if (!marked[adjacentVertex]) { marked[adjacentVertex] = true; edgeTo[adjacentVertex] = currentVertex; distanceTo[adjacentVertex] = distanceTo[currentVertex] + 1; // Update distance queue.offer(adjacentVertex); } } } } public int getDistance(int destination) { if (!marked[destination]) { return -1; } return distanceTo[destination]; } public void printPath(int start, int end) { if (start == end) { System.out.print(start); return; } if (edgeTo[end] == 0) { System.out.print("No path exists"); return; } printPath(start, edgeTo[end]); System.out.print(" -> " + end); } }
#ifndef BREADTHFIRSTSEARCH_H #define BREADTHFIRSTSEARCH_H #pragma once #include "UndirectedGraph.h" #include <vector> class BreadthFirstSearch { private: const UndirectedGraph& graph; int startVertex; std::vector<bool> marked; std::vector<int> edgeTo; std::vector<int> distanceTo; public: BreadthFirstSearch(const UndirectedGraph& graph, int startVertex); void bfs(); [[nodiscard]] int getDistance(int destination) const; void printPath(int destination) const; }; #endif //BREADTHFIRSTSEARCH_H
#include "BreadthFirstSearch.h"

#include <iostream>
#include <queue>

// Sizes the per-vertex state only; the search itself runs in bfs().
BreadthFirstSearch::BreadthFirstSearch(const UndirectedGraph& graph, const int startVertex)
    : graph(graph),
      startVertex(startVertex),
      marked(graph.getNumVertices(), false),
      edgeTo(graph.getNumVertices(), -1),
      distanceTo(graph.getNumVertices(), 0) {}

// FIFO order dequeues vertices in nondecreasing distance, so the first time a vertex is
// discovered is along a shortest path.
void BreadthFirstSearch::bfs() {
    std::queue<int> frontier;
    frontier.push(startVertex);
    marked[startVertex] = true;
    while (!frontier.empty()) {
        const int vertex = frontier.front();
        frontier.pop();
        for (const int& next : graph.getAdjacencyList()[vertex]) {
            if (marked[next]) {
                continue;
            }
            marked[next] = true;
            edgeTo[next] = vertex;
            distanceTo[next] = distanceTo[vertex] + 1;
            frontier.push(next);
        }
    }
}

int BreadthFirstSearch::getDistance(const int destination) const {
    return marked[destination] ? distanceTo[destination] : -1;
}

// Recursively prints the prefix of the path first, then the final hop.
void BreadthFirstSearch::printPath(const int destination) const {
    if (destination == startVertex) {
        std::cout << startVertex;
    } else if (edgeTo[destination] == -1) {
        std::cout << "No path exists";
    } else {
        printPath(edgeTo[destination]);
        std::cout << " -> " << destination;
    }
}
from collections import deque
from DirectedGraph import DirectedGraph
import sys


class BreadthFirstSearch:
    """Breadth-first search from a single start vertex (runs eagerly in ``__init__``).

    After construction, ``edge_to[v]`` is v's parent in the BFS tree (``None`` for the
    start vertex and for unreachable vertices) and ``distance_to[v]`` is the number of
    edges on a shortest path from the start (``sys.maxsize`` if v was not reached).
    """

    def __init__(self, graph: DirectedGraph, start_vertex: int):
        num_vertices = graph.get_num_vertices()
        self.marked = [False] * num_vertices
        self.edge_to = [None] * num_vertices
        self.distance_to = [sys.maxsize] * num_vertices
        self.bfs(graph, start_vertex)

    # ``graph`` is annotated with the imported DirectedGraph class: annotations are evaluated
    # when the method is defined, so an un-imported name here raises NameError at import time.
    def bfs(self, graph: DirectedGraph, start_vertex: int):
        """Visit every vertex reachable from ``start_vertex`` in breadth-first order."""
        adjacency = graph.get_adjacency_list()  # unchanged during the search; fetch once
        queue = deque([start_vertex])
        self.marked[start_vertex] = True
        self.distance_to[start_vertex] = 0
        while queue:
            current_vertex = queue.popleft()
            for adjacent_vertex in adjacency[current_vertex]:
                if not self.marked[adjacent_vertex]:
                    self.marked[adjacent_vertex] = True
                    self.edge_to[adjacent_vertex] = current_vertex
                    self.distance_to[adjacent_vertex] = self.distance_to[current_vertex] + 1
                    queue.append(adjacent_vertex)

    def get_distance(self, destination: int) -> int:
        """Return the BFS distance to ``destination``, or -1 if it was not reached."""
        if not self.marked[destination]:
            return -1
        return self.distance_to[destination]

    def print_path(self, start: int, end: int):
        """Print the BFS-tree path ``start -> ... -> end`` without a trailing newline.

        ``start`` is expected to be the vertex the search was run from.
        """
        if start == end:
            print(start, end="")
            return
        if self.edge_to[end] is None:
            # end="" keeps this consistent with the other prints here (caller adds the newline).
            print("No path exists", end="")
            return
        self.print_path(start, self.edge_to[end])
        print(f" -> {end}", end="")

15.4 Topological Sort

DAG: Directed Acyclic Graph.

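Topological sort: Order the vertices of a DAG so that every directed edge points from an earlier vertex to a later one (equivalently, redraw the DAG so all edges point upwards).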

Property: A digraph has a topological order if and only if it has no directed cycle.

Applications: precedence scheduling, cyclic-inheritance detection, spreadsheet recalculation, etc.

Topological Sort with DFS

  1. Run depth-first search.

  2. Return vertices in reverse postorder.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

/**
 * Topological sort by depth-first search: the reverse postorder of a DFS over all vertices.
 * No cycle check is performed; if the graph has a directed cycle the returned order is not
 * a valid topological order.
 */
public class TopologicalSort {
    private final DirectedGraph graph;
    private final boolean[] visited;
    /** Vertices pushed as their DFS finishes; popping yields reverse postorder. */
    private final Deque<Integer> postorder;

    public TopologicalSort(DirectedGraph graph) {
        this.graph = graph;
        this.visited = new boolean[graph.getNumVertices()];
        this.postorder = new ArrayDeque<>();
    }

    /**
     * @return the vertices in reverse DFS postorder (a topological order when the graph is a DAG)
     */
    public List<Integer> topologicalSort() {
        // Reset per-call state: otherwise every vertex is still marked from a previous call
        // and a second call returns an empty list.
        Arrays.fill(visited, false);
        postorder.clear();

        for (int v = 0; v < graph.getNumVertices(); v++) {
            if (!visited[v]) {
                dfs(v);
            }
        }
        List<Integer> sortedVertices = new ArrayList<>(postorder.size());
        while (!postorder.isEmpty()) {
            sortedVertices.add(postorder.pop());
        }
        return sortedVertices;
    }

    /** Recursive DFS from v; pushes v once every vertex reachable from it has finished. */
    private void dfs(int v) {
        visited[v] = true;
        for (Integer neighbor : graph.getAdjacencyList().get(v)) {
            if (!visited[neighbor]) {
                dfs(neighbor);
            }
        }
        postorder.push(v);
    }
}
#ifndef TOPOLOGICALSORT_H #define TOPOLOGICALSORT_H #pragma once #include "DirectedGraph.h" #include <vector> #include <stack> class TopologicalSort { private: const DirectedGraph& graph; std::vector<bool> visited; std::stack<int> postorder; void dfs(int v); public: explicit TopologicalSort(const DirectedGraph& graph); std::vector<int> topologicalSort(); }; #endif //TOPOLOGICALSORT_H
#include "TopologicalSort.h"

TopologicalSort::TopologicalSort(const DirectedGraph& graph)
    : graph(graph), visited(graph.getNumVertices(), false) {}

// Recursive DFS from v; pushes v onto postorder once every vertex reachable from it has finished.
void TopologicalSort::dfs(const int v) {
    visited[v] = true;
    for (const int neighbor : graph.getAdjacencyList()[v]) {
        if (!visited[neighbor]) {
            dfs(neighbor);
        }
    }
    postorder.push(v);
}

// Returns the vertices in reverse DFS postorder, which is a topological order when the graph is
// a DAG. No cycle check is performed.
std::vector<int> TopologicalSort::topologicalSort() {
    // Reset per-call state: without this, every vertex is still marked from a previous call and
    // a second call returns an empty vector. (postorder is always drained by the loop below.)
    visited.assign(graph.getNumVertices(), false);

    for (int v = 0; v < graph.getNumVertices(); ++v) {
        if (!visited[v]) {
            dfs(v);
        }
    }

    std::vector<int> sortedVertices;
    sortedVertices.reserve(postorder.size());
    while (!postorder.empty()) {
        sortedVertices.push_back(postorder.top());
        postorder.pop();
    }
    return sortedVertices;
}
class TopologicalSort:
    """Topological sort of a directed graph by DFS reverse postorder.

    ``graph`` must provide ``get_num_vertices()`` and an ``adjacency_list`` indexed by
    vertex. No cycle check is performed: if the graph has a directed cycle, the result
    is not a valid topological order.
    """

    def __init__(self, graph):
        self.graph = graph  # Store the DirectedGraph object
        self.visited = [False] * self.graph.get_num_vertices()
        self.postorder = []  # vertices in the order their DFS finishes

    def topological_sort(self):
        """Return the vertices in reverse DFS postorder (a topological order for a DAG)."""
        for v in range(self.graph.get_num_vertices()):
            if not self.visited[v]:
                self.dfs(v)
        return self.postorder[::-1]

    def dfs(self, v):
        """Depth-first search from ``v``, appending each vertex to ``postorder`` as it finishes.

        Uses an explicit stack of (vertex, neighbor-iterator) pairs instead of recursion, so
        long paths do not hit Python's recursion limit. Vertices are entered and finished in
        exactly the same order as in the recursive formulation.
        """
        self.visited[v] = True
        stack = [(v, iter(self.graph.adjacency_list[v]))]
        while stack:
            vertex, neighbors = stack[-1]
            for neighbor in neighbors:
                if not self.visited[neighbor]:
                    # Descend: mark on entry, then resume this vertex's iterator later.
                    self.visited[neighbor] = True
                    stack.append((neighbor, iter(self.graph.adjacency_list[neighbor])))
                    break
            else:
                # All neighbors handled: the vertex is finished.
                stack.pop()
                self.postorder.append(vertex)

15.4.2 Algorithm Ⅱ - Kahn's Algorithm

Topological Sort with Kahn's Algorithm

  1. Calculate in-degrees.

  2. Find nodes with in-degree 0.

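  3. Repeatedly remove a node with in-degree 0, append it to the output order, and decrement the in-degree of each of its neighbors; any neighbor whose in-degree drops to 0 becomes ready. If fewer than V nodes are output, the graph contains a directed cycle.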

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

/** Topological sort via Kahn's algorithm (repeatedly removing vertices of in-degree 0). */
public class TopologicalSort {
    /**
     * @param graph directed graph to sort
     * @return a topological order of graph's vertices, or an empty list (after printing an
     *         error to stderr) if the graph contains a directed cycle
     */
    public static List<Integer> topologicalSort(DirectedGraph graph) {
        int numVertices = graph.getNumVertices();
        List<List<Integer>> adjList = graph.getAdjacencyList();

        int[] inDegree = new int[numVertices];
        for (int u = 0; u < numVertices; u++) {
            for (int v : adjList.get(u)) {
                inDegree[v]++;
            }
        }

        // ArrayDeque is the recommended general-purpose Queue; LinkedList allocates a node per element.
        Queue<Integer> queue = new ArrayDeque<>();
        for (int i = 0; i < numVertices; i++) {
            if (inDegree[i] == 0) {
                queue.offer(i);
            }
        }

        List<Integer> sortedOrder = new ArrayList<>(numVertices);
        while (!queue.isEmpty()) {
            int u = queue.poll();
            sortedOrder.add(u);
            for (int v : adjList.get(u)) {
                if (--inDegree[v] == 0) {
                    queue.offer(v);
                }
            }
        }

        // Vertices on (or downstream of) a directed cycle never reach in-degree 0.
        if (sortedOrder.size() != numVertices) {
            System.err.println("Error: Graph contains a cycle!");
            return new ArrayList<>();
        }
        return sortedOrder;
    }
}
#include "DirectedGraph.h"

#include <cstddef>
#include <iostream>
#include <queue>
#include <vector>

/// Topological sort via Kahn's algorithm.
/// @return the vertices in topological order, or an empty vector (after printing an error to
///         stderr) if the graph contains a directed cycle.
std::vector<int> topologicalSort(const DirectedGraph& graph) {
    const int numVertices = graph.getNumVertices();
    // Fetch once: re-evaluating getAdjacencyList() inside the loops copies the whole structure
    // each time if it returns by value.
    const auto& adjacency = graph.getAdjacencyList();

    std::vector<int> inDegree(numVertices, 0);
    for (int u = 0; u < numVertices; ++u) {
        for (const int v : adjacency[u]) {
            ++inDegree[v];
        }
    }

    std::queue<int> queue;
    for (int i = 0; i < numVertices; ++i) {
        if (inDegree[i] == 0) {
            queue.push(i);
        }
    }

    std::vector<int> sortedOrder;
    sortedOrder.reserve(static_cast<std::size_t>(numVertices));
    while (!queue.empty()) {
        const int u = queue.front();
        queue.pop();
        sortedOrder.push_back(u);
        for (const int v : adjacency[u]) {
            if (--inDegree[v] == 0) {
                queue.push(v);
            }
        }
    }

    // Check for cycles: vertices on (or downstream of) a directed cycle never reach in-degree 0.
    // Compare as int to avoid a signed/unsigned comparison.
    if (static_cast<int>(sortedOrder.size()) != numVertices) {
        std::cerr << "Error: Graph contains a cycle!" << '\n';
        return {};
    }
    return sortedOrder;
}
def topological_sort(graph):
    """Topologically sort ``graph`` with Kahn's algorithm.

    ``graph`` must provide ``get_num_vertices()`` and an ``adjacency_list`` indexed by vertex.
    Returns the vertices in topological order, or ``None`` if the graph has a directed cycle.
    """
    from collections import deque  # local import keeps this function self-contained

    num_vertices = graph.get_num_vertices()
    in_degree = [0] * num_vertices
    for u in range(num_vertices):
        for v in graph.adjacency_list[u]:
            in_degree[v] += 1

    # deque.popleft() is O(1); list.pop(0) shifts every remaining element (O(V) per dequeue).
    queue = deque(i for i in range(num_vertices) if in_degree[i] == 0)
    sorted_order = []
    while queue:
        u = queue.popleft()
        sorted_order.append(u)
        for v in graph.adjacency_list[u]:
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)

    # Vertices on (or downstream of) a directed cycle never reach in-degree 0.
    if len(sorted_order) != num_vertices:
        return None
    return sorted_order

15.5 Strong Components

Connected Components

Strongly-Connected Components

Definition

$v$ and $w$ are connected if there is a path between $v$ and $w$.

$v$ and $w$ are strongly connected if there is a directed path from $v$ to $w$ and a directed path from $w$ to $v$.

Implementation

DFS

DFS & Reverse DFS

Detail

Connected Components
Strongly-Connected Components

Strongly-Connected Components

  1. Compute topological order (reverse postorder) in kernel DAG.

  2. Run DFS, considering vertices in reverse topological order.

import java.util.Stack;

/**
 * Strongly connected components via Kosaraju's algorithm. A first DFS pass over the graph
 * records vertices in finishing order; a second pass over the transposed graph, starting
 * from vertices in decreasing finishing order, finds one component per DFS tree.
 * Each component is printed to stdout on its own line as "SCC k: v1 v2 ...".
 */
public class StronglyConnectedComponents {
    private final DirectedGraph graph;
    private boolean[] visited;
    /** Vertices pushed in first-pass finishing order; popped in reverse finishing order. */
    private final Stack<Integer> stack;
    private int sccCount;

    public StronglyConnectedComponents(DirectedGraph graph) {
        this.graph = graph;
        this.visited = new boolean[graph.getNumVertices()];
        this.stack = new Stack<>();
        this.sccCount = 0;
    }

    /** Finds and prints every strongly connected component of the graph. */
    public void findStronglyConnectedComponents() {
        // Reset per-call state: otherwise every vertex is still marked from a previous call and
        // nothing is printed, while sccCount keeps counting from the old total.
        visited = new boolean[graph.getNumVertices()];
        stack.clear();
        sccCount = 0;

        // Pass 1: DFS on the original graph, recording finishing order.
        for (int i = 0; i < graph.getNumVertices(); i++) {
            if (!visited[i]) {
                dfsFirst(i);
            }
        }

        // Pass 2: DFS on the transpose in decreasing finishing order; each tree is one SCC.
        DirectedGraph transposedGraph = transposeGraph();
        visited = new boolean[graph.getNumVertices()];
        while (!stack.isEmpty()) {
            int vertex = stack.pop();
            if (!visited[vertex]) {
                sccCount++;
                System.out.print("SCC " + sccCount + ": ");
                dfsSecond(transposedGraph, vertex);
                System.out.println();
            }
        }
    }

    /** First-pass DFS; pushes vertex once all vertices reachable from it have finished. */
    private void dfsFirst(int vertex) {
        visited[vertex] = true;
        for (int neighbor : graph.getAdjacencyList().get(vertex)) {
            if (!visited[neighbor]) {
                dfsFirst(neighbor);
            }
        }
        stack.push(vertex);
    }

    /** Second-pass DFS on the transpose; prints every not-yet-visited vertex it reaches. */
    private void dfsSecond(DirectedGraph transposedGraph, int vertex) {
        visited[vertex] = true;
        System.out.print(vertex + " ");
        for (int neighbor : transposedGraph.getAdjacencyList().get(vertex)) {
            if (!visited[neighbor]) {
                dfsSecond(transposedGraph, neighbor);
            }
        }
    }

    /** Returns a new graph with every edge u -> v reversed to v -> u. */
    private DirectedGraph transposeGraph() {
        DirectedGraph transposedGraph = new DirectedGraph(graph.getNumVertices());
        for (int i = 0; i < graph.getNumVertices(); i++) {
            for (int neighbor : graph.getAdjacencyList().get(i)) {
                transposedGraph.addEdge(neighbor, i);
            }
        }
        return transposedGraph;
    }
}
#ifndef STRONGLYCONNECTEDCOMPONENTS_H #define STRONGLYCONNECTEDCOMPONENTS_H #pragma once #include "DirectedGraph.h" #include <vector> #include <stack> class StronglyConnectedComponents { private: const DirectedGraph& graph; std::vector<bool> visited; std::stack<int> finishingStack; int sccCount; void dfsFirst(int vertex); void dfsSecond(const DirectedGraph& transposedGraph, int vertex); public: explicit StronglyConnectedComponents(const DirectedGraph& graph); void findStronglyConnectedComponents(); }; #endif //STRONGLYCONNECTEDCOMPONENTS_H
#include "StronglyConnectedComponents.h"

#include <iostream>

StronglyConnectedComponents::StronglyConnectedComponents(const DirectedGraph& graph)
    : graph(graph), visited(graph.getNumVertices(), false), sccCount(0) {}

// First pass: DFS over the original graph; pushes vertex once everything reachable from it finished.
void StronglyConnectedComponents::dfsFirst(int vertex) {
    visited[vertex] = true;
    for (const int neighbor : graph.getAdjacencyList()[vertex]) {
        if (!visited[neighbor]) {
            dfsFirst(neighbor);
        }
    }
    finishingStack.push(vertex);
}

// Second pass: DFS over the transposed graph; prints every not-yet-visited vertex it reaches.
void StronglyConnectedComponents::dfsSecond(const DirectedGraph& transposedGraph, int vertex) {
    visited[vertex] = true;
    std::cout << vertex << " ";
    for (const int neighbor : transposedGraph.getAdjacencyList()[vertex]) {
        if (!visited[neighbor]) {
            dfsSecond(transposedGraph, neighbor);
        }
    }
}

// Kosaraju's algorithm: record finishing order on the graph, then run DFS on the transpose in
// decreasing finishing order. Each second-pass DFS tree is one SCC, printed on its own line.
void StronglyConnectedComponents::findStronglyConnectedComponents() {
    const int numVertices = graph.getNumVertices();

    // Reset per-call state: otherwise every vertex is still marked from a previous call and
    // nothing is printed, while sccCount keeps counting from the old total.
    visited.assign(numVertices, false);
    sccCount = 0;

    for (int i = 0; i < numVertices; ++i) {
        if (!visited[i]) {
            dfsFirst(i);
        }
    }

    // Build the transpose (every edge u -> v becomes v -> u); fetch the adjacency list once.
    DirectedGraph transposedGraph(numVertices);
    const auto& adjacency = graph.getAdjacencyList();
    for (int u = 0; u < numVertices; ++u) {
        for (const int v : adjacency[u]) {
            transposedGraph.addEdge(v, u);
        }
    }

    visited.assign(numVertices, false);
    while (!finishingStack.empty()) {
        const int vertex = finishingStack.top();
        finishingStack.pop();
        if (!visited[vertex]) {
            std::cout << "SCC " << ++sccCount << ": ";
            dfsSecond(transposedGraph, vertex);
            std::cout << '\n';  // '\n' instead of std::endl: no flush per component
        }
    }
}
class StronglyConnectedComponents:
    """Strongly connected components of a directed graph (Kosaraju's algorithm).

    ``graph`` must expose ``num_vertices`` and an ``adjacency_list`` indexed by vertex.
    Components are printed to stdout, one per line, as ``SCC k: v1 v2 ...``.
    """

    def __init__(self, graph):
        self.graph = graph
        self.visited = [False] * graph.num_vertices
        self.finishing_stack = []  # vertices in first-pass finishing order
        self.scc_count = 0

    def dfs_first(self, vertex):
        """First pass: DFS on the original graph; record ``vertex`` once it finishes."""
        self.visited[vertex] = True
        for neighbor in self.graph.adjacency_list[vertex]:
            if not self.visited[neighbor]:
                self.dfs_first(neighbor)
        self.finishing_stack.append(vertex)

    def dfs_second(self, transposed_graph, vertex):
        """Second pass: DFS on the transpose, printing every newly visited vertex."""
        self.visited[vertex] = True
        print(f"{vertex} ", end="")
        for neighbor in transposed_graph.adjacency_list[vertex]:
            if not self.visited[neighbor]:
                self.dfs_second(transposed_graph, neighbor)

    def find_strongly_connected_components(self):
        """Find and print every strongly connected component."""
        # DirectedGraph was previously used here without being imported (NameError when this
        # snippet runs on its own); import it from the same module the BFS snippet uses.
        from DirectedGraph import DirectedGraph

        # Reset per-call state so a repeated call reports the components again from SCC 1.
        self.visited = [False] * self.graph.num_vertices
        self.finishing_stack = []
        self.scc_count = 0

        # 1. DFS on original graph to get finishing times
        for i in range(self.graph.num_vertices):
            if not self.visited[i]:
                self.dfs_first(i)

        # 2. Build the transpose: every edge u -> v becomes v -> u.
        transposed_graph = DirectedGraph(self.graph.num_vertices)
        for i in range(self.graph.num_vertices):
            for neighbor in self.graph.adjacency_list[i]:
                transposed_graph.add_edge(neighbor, i)

        # 3. DFS on the transpose in decreasing finishing time; each tree is one SCC.
        self.visited = [False] * self.graph.num_vertices
        while self.finishing_stack:
            vertex = self.finishing_stack.pop()
            if not self.visited[vertex]:
                self.scc_count += 1
                print(f"SCC {self.scc_count}: ", end="")
                self.dfs_second(transposed_graph, vertex)
                print()
Last modified: 25 November 2024