Scalaを使ってみる: (9) immutable MultiSetのメソッドを高速化する

このエントリーをはてなブックマークに追加

Scalaを勉強している.勉強中の身ではあるが,以下を例題として,Scalaプログラムの作り方について説明してみる.

テキストファイル中に現れる英単語の出現回数を数えて,出現回数の多い語から表示する.
入力のテキストファイルとしては,Project Gutenberg 中のHalmet を用いる(ファイル名を hamlet.txt,改行はLFにした).
なお利用している環境は,Ubuntu 8.04 LTS上の Scala 2.8.0 RC2 (2010年5月10日リリース), Java 1.6.0である.

高速化

前回のプログラムでは,多くのメソッドは iterator を用いて実装されている.たとえばマルチ集合の size は,iterator.size と同様の処理で求められている.
今回は,めぼしいメソッドについて,効率的な実装に書き換えてみる.

++ と -- メソッド

++ は union と,-- は diff と同様なので,以下のように定義する.

  def ++ (that: MS[A]) = union(that)
  def -- (that: MS[A]) = diff(that)

size と count メソッド

size は amap.values.reduceRight(_ + _) で良さそうだ.
あるいは amap.keys.map(amap(_)).sum でも良いような気がする.確かめて見よう.

  scala> val m = Map("a"->1, "b"->1)
  scala.collection.immutable.Map[java.lang.String,Int] = Map((a,1), (b,1))

  scala> m.keys.map(m(_)).sum
  Int = 1

なぜか 2 ではなくて,1 になってしまう.
これは m.keys が Set であるため, m.keys.map(m(_)) も値の Set になってしまい,重複要素が取り除かれるためである.

  scala> m.keys.map(m(_)).sum
  Iterable[Int] = Set(1)

同様の問題は,for comprehension でも生じるので,注意する必要がある.

  scala> for (x <- m.keys) yield m(x)
  Iterable[Int] = Set(1)

Map のキーを List にすれば,問題は生じない.

  scala> m.keys.toList.map(m(_)).sum
  Int = 2

あるいは keysIterator を利用するのでも良い.そこで size と count を以下のように定義する( keys は amap.keysIterator と定義している).

  override def size = keys.map(this(_)).sum
  override def count(p: A => Boolean) =
    keys.filter(p).map(this(_)).sum

find, exists, forall メソッド

find, exists, forall メソッドは,キーに対して同様の処理を行えば良い.

  override def find(p: A => Boolean) = keys.find(p)
  override def exists(p: A => Boolean) = keys.exists(p)
  override def forall(p: A => Boolean) = keys.forall(p)

max と min メソッド

max と min メソッドも,キーに対して同様の処理を行えば良い.

  override def max[B >: A](implicit cmp: Ordering[B]): A =
    keys.max(cmp)
  override def min[B >: A](implicit cmp: Ordering[B]): A =
    keys.min(cmp)

head と last メソッド

head や last メソッドも,キーに対する同様の処理で実装できる.

  override def head = amap.keys.head
  override def headOption = amap.keys.headOption
  override def tail = this - head
  override def last = amap.keys.last
  override def lastOption = amap.keys.lastOption
  override def init = this - last

マルチ集合のプログラム

まだ実装すべきメソッドは残っているが,以上をまとめると,以下のプログラムになる

import scala.collection._
import scala.collection.generic._
import scala.collection.mutable.Builder
import scala.collection.mutable.ListBuffer

case class MS[A] private(val amap: Map[A,Int]) extends Iterable[A]
with GenericTraversableTemplate[A,MS] with IterableLike[A,MS[A]]
with Addable[A,MS[A]] with Subtractable[A,MS[A]]  {
  override def companion = MS
  def iterator = amap.keysIterator.flatMap{
    elem => Iterator.fill[A](amap(elem))(elem)
  }
  def empty = MS.empty[A]
  def contains(elem: A) = amap.contains(elem)
  def apply(elem: A) = amap.getOrElse(elem, 0)
  def andThen[B](g: Int => B) = amap.andThen(g)
  def + (elem: A) = new MS(amap + (elem -> (this(elem)+1)))
  def - (elem: A) = this(elem) match {
    case 0 => this
    case 1 => new MS(amap - elem)
    case c => new MS(amap + (elem -> (c-1)))
  }
  def intersect(that: MS[A]) = {
    val m = collection.mutable.Map[A,Int]()
    for (x <- keys if that.contains(x))
      m += x -> (this(x) min that(x))
    new MS(m.toMap)
  }
  def & (that: MS[A]) = intersect(that)
  def union(that: MS[A]) = {
    val m = collection.mutable.Map[A,Int]() ++ amap
    for (x <- that.keys)
      m += x -> (this(x) + that(x))
    new MS(m.toMap)
  }
  def | (that: MS[A]) = union(that)
  def diff(that: MS[A]) = {
    val m = collection.mutable.Map[A,Int]() ++ amap
    for (x <- that.keys if contains(x)) {
      val c = this(x) - that(x)
      if (c > 0) m += x -> c else m -= x
    }
    new MS(m.toMap)
  }
  def &~ (that: MS[A]) = diff(that)
  def subsetOf(that: MS[A]) =
    keys.forall(x => this(x) <= that(x))
  def toMap = amap
  def toSet = amap.keySet
  def keys = amap.keysIterator

  def ++ (that: MS[A]) = union(that)
  def -- (that: MS[A]) = diff(that)

  override def size = keys.map(this(_)).sum
  override def count(p: A => Boolean) =
    keys.filter(p).map(this(_)).sum

  override def find(p: A => Boolean) = keys.find(p)
  override def exists(p: A => Boolean) = keys.exists(p)
  override def forall(p: A => Boolean) = keys.forall(p)

  override def max[B >: A](implicit cmp: Ordering[B]): A =
    keys.max(cmp)
  override def min[B >: A](implicit cmp: Ordering[B]): A =
    keys.min(cmp)

  override def head = amap.keys.head
  override def headOption = amap.keys.headOption
  override def tail = this - head
  override def last = amap.keys.last
  override def lastOption = amap.keys.lastOption
  override def init = this - last
}

object MS extends TraversableFactory[MS] {
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MS[A]] =
    new GenericCanBuildFrom[A]
  def newBuilder[A] = new Builder[A,MS[A]] {
    private var amap = collection.mutable.Map[A,Int]()
    def clear = amap.clear
    def += (elem: A) = {
      amap += (elem -> (amap.getOrElse(elem,0)+1))
      this
    }
    def result = new MS(amap.toMap)
  }
}

実行時間の計測

まだ実装すべきメソッドは残っているが,一旦ここで実行時間を計測してみる.
まず,前回の(8) MultiSetのメソッドを定義する (immutable版)で作成したプログラムをコンパイルし実行してみる.
最初に乱数で,総要素数が10000のマルチ集合を作成する.

  scala> val rnd = new util.Random()
  rnd: scala.util.Random = scala.util.Random@1a8ecf4

  scala> val ms = MS[Int]() ++ (1 to 10000).map(i => rnd.nextInt(100))
  ms: MS[Int] = MS(45, 45, 45, 45, 45, ...)

次に計測のための関数を定義する.

  scala> def time(n: Int, cmd: => Any) = {
       |   val times = (new testing.Benchmark { def run=cmd }).runBenchmark(n)
       |   println(times)
       |   times.reduceLeft(_ + _).toDouble / n
       | }
  time: (n: Int,cmd: => Any)Double

(ms ++ (ms ++ (ms ++ ... (ms ++ ms) ...))) のように10個の ms を連結して最後にそのサイズをもとめる.

  scala> (1 to 10).map(i => ms).reduceRight(_ ++ _).size
  Int = 100000

これを20回実行し,平均の実行時間を求めると約400ミリ秒だった.

  scala> time(20, (1 to 10).map(i => ms).reduceRight(_ ++ _).size)
  List(569, 431, 432, 385, 382, 385, 382, 383, 387, 380, 377, 381, 380, 378, 380, 381, 378, 379, 379, 378)
  Double = 395.35

同様に,今回のプログラムをコンパイルし実行してみると約6ミリ秒だった.(JITの影響だろうか? 徐々に速くなっている).

  scala> time(20, (1 to 10).map(i => ms).reduceRight(_ ++ _).size)
  List(10, 14, 7, 7, 6, 7, 6, 6, 6, 6, 5, 5, 4, 3, 3, 3, 4, 3, 4, 4)
  Double = 5.65

次回は,残りのメソッドの実装を目指す.