Ruby で二分探索木

勉強のために Ruby二分探索木を書いてみました。
二分探索木の Ruby 実装 · GitHub
なお、以下のコードは説明のためなので、Gist のコードを多少簡略化しています。
 
木(Tree)のクラスと要素(Node)のクラスを定義します。Node は Struct クラスを使って簡単に定義します。key と value の値が使えます。Node を追加するたび、key の順序で整序します。

Node = Struct.new(:key, :value, :left, :right)

Tree クラスは、要素の追加と探索はこんな風になります。

class Tree
  def initialize
    @root = nil
  end
  
  def insert(key, value)
    unless @root
      @root = Node.new(key, value)
      return
    end
    node = @root
    while node
      if key < node.key
        if node.left
          node = node.left
        else
          node.left = Node.new(key, value)
          break
        end
      elsif key > node.key
        if node.right
          node = node.right
        else
          node.right = Node.new(key, value)
          break
        end
      else
        if block_given?
          node.value = yield(key, node.value, value)
        else
          node.value = value
        end
        break
      end
    end
  end
  
  def []=(key, value)
    insert(key, value)
  end
  
  # nodeを返す
  def search(key)
    node = @root
    while node
      if key < node.key
        node = node.left
      elsif key > node.key
        node = node.right
      else
        return node
      end
    end
    nil
  end
  
  # valueを返す
  def [](key)
    search(key)&.value
  end

 
これだけでこんな風に使えます。まず、t[key] = value という具合に、Node を登録しています。

$ pry
[1] pry(main)> require './binary_search_tree'
=> true
[2] pry(main)> t = Tree.new
=> #<Tree:@root=>
[3] pry(main)> t[5] = :いちご
=> :いちご
[4] pry(main)> t[3] = :みかん
=> :みかん
[5] pry(main)> t[10] = :りんご
=> :りんご
[6] pry(main)> t
=> #<Tree:@root=#<Node:@key='5', @value='いちご', @left='#<Node:@key='3', @value='みかん'>', @right='#<Node:@key='10', @value='りんご'>'>>

t[key] でハッシュのように value を見ることができます。

[7] pry(main)> t[10]
=> :りんご
[8] pry(main)> t[5]
=> :いちご

メソッド print で木を簡単に出力します。

[9] pry(main)> t.print
 -- 5 -- 3
      -- 10

さらに Node を追加します。key の順序で自動的に整序されているのがわかると思います。

[10] pry(main)> t[4] = :柿
=> :柿
[11] pry(main)> t.print
 -- 5 -- 3 --
           -- 4
      -- 10

メソッド each で木をトラバースします。

[12] pry(main)> t.each {|k, v| puts "#{k} --> #{v}"}
3 --> みかん
4 --> 柿
5 --> いちご
10 --> りんご

なお、each メソッドが定義されているので、Enumerable を include してあります。Enumerable のすべてのメソッドが使えます。

むずかしいのは Node の削除です。二分探索木の特性を壊さないように削除しなければなりません。コードはこんな感じ。

class Tree
  # Nodeを削除する
  def delete(key)
    delete_min = ->(node) {
      return node.right unless node.left
      node.left = delete_min.(node.left)
      node
    }
    del = ->(node) {
      value = nil
      if node
        if key == node.key
          value = node.value
          if node.left.nil?
            return node.right, value
          elsif node.right.nil?
            return node.left, value
          else
            min_node = search_min(node.right)
            node.key = min_node.key
            node.value = min_node.value
            node.right = delete_min.(node.right)
          end
        elsif key < node.key
          node.left , value = del.(node.left)
        else
          node.right, value = del.(node.right)
        end
      end
      return node, value
    }
    
    @root, value = del.(@root)
    value
  end
end

ここで呼ばれている search_min メソッドは、それより下のすべての Node の中から、key が最小のものを探してくるものです。

使ってみます。root を削除してみます。

[13] pry(main)> t.delete(5)
=> :いちご
[14] pry(main)> t.print
 -- 10 -- 3 --
            -- 4
       --

ちゃんと規則を壊さぬように削除できています。

最大、最小の key を探索できます。二分探索なので、一般的に高速になります。

[15] pry(main)> t.minimum
=> [3, :みかん]

 
他にもまだいろいろなメソッドが定義してありますが、これくらいが最小限度の機能だと思います。
 

メソッド一覧

  • insert(key, value)
    • []= はこれを使って実装してあるが、insert の場合はブロックを取って、キーの重複があった場合の処理ができる。
  • []=(key, value)
    • insert とほぼ同じだが、key が重複した場合は上書きする。
  • search(key)
  • [](key)
  • delete(key)
  • preorder(key = nil, &bk)
    • 行きがけ。key を省略した場合は root からトラバースする。
  • inorder(key = nil, &bk)
    • 通りがけ。
  • postorder(key = nil, &bk)
    • 帰りがけ。
  • each(&bk)
    • inorder を使って実装してある。不都合なら再定義する。
  • breadth_first_search(key = nil, &bk)
  • bfs(key = nil, &bk)
  • minimum
  • maximum
  • inspect
  • right_rotate(key)
  • left_rotate(key)
  • parent_by_node(target)
  • parent(key)
  • add_hash(hash)
    • Hash を使って一気に insert する。
  • print