Scalaを使ってみる: (9) immutable MultiSetのメソッドを高速化する
Scalaを勉強している.勉強中の身ではあるが,以下を例題として,Scalaプログラムの作り方について説明してみる.
テキストファイル中に現れる英単語の出現回数を数えて,出現回数の多い語から表示する.入力のテキストファイルとしては,Project Gutenberg 中のHalmet を用いる(ファイル名を hamlet.txt,改行はLFにした).
なお利用している環境は,Ubuntu 8.04 LTS上の Scala 2.8.0 RC2 (2010年5月10日リリース), Java 1.6.0である.
高速化
前回のプログラムでは,多くのメソッドは iterator を用いて実装されている.たとえばマルチ集合の size は,iterator.size と同様の処理で求められている.
今回は,めぼしいメソッドについて,効率的な実装に書き換えてみる.
++ と -- メソッド
++ は union と,-- は diff と同様なので,以下のように定義する.
def ++ (that: MS[A]) = union(that) def -- (that: MS[A]) = diff(that)
size と count メソッド
size は amap.values.reduceRight(_ + _) で良さそうだ.
あるいは amap.keys.map(amap(_)).sum でも良いような気がする.確かめて見よう.
scala> val m = Map("a"->1, "b"->1) scala.collection.immutable.Map[java.lang.String,Int] = Map((a,1), (b,1)) scala> m.keys.map(m(_)).sum Int = 1
なぜか 2 ではなくて,1 になってしまう.
これは m.keys が Set であるため, m.keys.map(m(_)) も値の Set になってしまい,重複要素が取り除かれるためである.
scala> m.keys.map(m(_)).sum
Iterable[Int] = Set(1)
同様の問題は,for comprehension でも生じるので,注意する必要がある.
scala> for (x <- m.keys) yield m(x) Iterable[Int] = Set(1)
Map のキーを List にすれば,問題は生じない.
scala> m.keys.toList.map(m(_)).sum
Int = 2
あるいは keysIterator を利用するのでも良い.そこで size と count を以下のように定義する( keys は amap.keysIterator と定義している).
override def size = keys.map(this(_)).sum override def count(p: A => Boolean) = keys.filter(p).map(this(_)).sum
find, exists, forall メソッド
find, exists, forall メソッドは,キーに対して同様の処理を行えば良い.
override def find(p: A => Boolean) = keys.find(p) override def exists(p: A => Boolean) = keys.exists(p) override def forall(p: A => Boolean) = keys.forall(p)
max と min メソッド
max と min メソッドも,キーに対して同様の処理を行えば良い.
override def max[B >: A](implicit cmp: Ordering[B]): A = keys.max(cmp) override def min[B >: A](implicit cmp: Ordering[B]): A = keys.min(cmp)
head と last メソッド
head や last メソッドも,キーに対する同様の処理で実装できる.
override def head = amap.keys.head override def headOption = amap.keys.headOption override def tail = this - head override def last = amap.keys.last override def lastOption = amap.keys.lastOption override def init = this - last
マルチ集合のプログラム
まだ実装すべきメソッドは残っているが,以上をまとめると,以下のプログラムになる
import scala.collection._ import scala.collection.generic._ import scala.collection.mutable.Builder import scala.collection.mutable.ListBuffer case class MS[A] private(val amap: Map[A,Int]) extends Iterable[A] with GenericTraversableTemplate[A,MS] with IterableLike[A,MS[A]] with Addable[A,MS[A]] with Subtractable[A,MS[A]] { override def companion = MS def iterator = amap.keysIterator.flatMap{ elem => Iterator.fill[A](amap(elem))(elem) } def empty = MS.empty[A] def contains(elem: A) = amap.contains(elem) def apply(elem: A) = amap.getOrElse(elem, 0) def andThen[B](g: Int => B) = amap.andThen(g) def + (elem: A) = new MS(amap + (elem -> (this(elem)+1))) def - (elem: A) = this(elem) match { case 0 => this case 1 => new MS(amap - elem) case c => new MS(amap + (elem -> (c-1))) } def intersect(that: MS[A]) = { val m = collection.mutable.Map[A,Int]() for (x <- keys if that.contains(x)) m += x -> (this(x) min that(x)) new MS(m.toMap) } def & (that: MS[A]) = intersect(that) def union(that: MS[A]) = { val m = collection.mutable.Map[A,Int]() ++ amap for (x <- that.keys) m += x -> (this(x) + that(x)) new MS(m.toMap) } def | (that: MS[A]) = union(that) def diff(that: MS[A]) = { val m = collection.mutable.Map[A,Int]() ++ amap for (x <- that.keys if contains(x)) { val c = this(x) - that(x) if (c > 0) m += x -> c else m -= x } new MS(m.toMap) } def &~ (that: MS[A]) = diff(that) def subsetOf(that: MS[A]) = keys.forall(x => this(x) <= that(x)) def toMap = amap def toSet = amap.keySet def keys = amap.keysIterator def ++ (that: MS[A]) = union(that) def -- (that: MS[A]) = diff(that) override def size = keys.map(this(_)).sum override def count(p: A => Boolean) = keys.filter(p).map(this(_)).sum override def find(p: A => Boolean) = keys.find(p) override def exists(p: A => Boolean) = keys.exists(p) override def forall(p: A => Boolean) = keys.forall(p) override def max[B >: A](implicit cmp: Ordering[B]): A = keys.max(cmp) override def min[B >: A](implicit cmp: Ordering[B]): A = keys.min(cmp) override def head = amap.keys.head override def headOption = amap.keys.headOption override def tail = this - head override def last = amap.keys.last override def lastOption = amap.keys.lastOption override def init = this - last } object MS extends TraversableFactory[MS] { implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MS[A]] = new GenericCanBuildFrom[A] def newBuilder[A] = new Builder[A,MS[A]] { private var amap = collection.mutable.Map[A,Int]() def clear = amap.clear def += (elem: A) = { amap += (elem -> (amap.getOrElse(elem,0)+1)) this } def result = new MS(amap.toMap) } }
実行時間の計測
まだ実装すべきメソッドは残っているが,一旦ここで実行時間を計測してみる.
まず,前回の(8) MultiSetのメソッドを定義する (immutable版)で作成したプログラムをコンパイルし実行してみる.
最初に乱数で,総要素数が10000のマルチ集合を作成する.
scala> val rnd = new util.Random() rnd: scala.util.Random = scala.util.Random@1a8ecf4 scala> val ms = MS[Int]() ++ (1 to 10000).map(i => rnd.nextInt(100)) ms: MS[Int] = MS(45, 45, 45, 45, 45, ...)
次に計測のための関数を定義する.
scala> def time(n: Int, cmd: => Any) = { | val times = (new testing.Benchmark { def run=cmd }).runBenchmark(n) | println(times) | times.reduceLeft(_ + _).toDouble / n | } time: (n: Int,cmd: => Any)Double
(ms ++ (ms ++ (ms ++ ... (ms ++ ms) ...))) のように10個の ms を連結して最後にそのサイズをもとめる.
scala> (1 to 10).map(i => ms).reduceRight(_ ++ _).size Int = 100000
これを20回実行し,平均の実行時間を求めると約400ミリ秒だった.
scala> time(20, (1 to 10).map(i => ms).reduceRight(_ ++ _).size) List(569, 431, 432, 385, 382, 385, 382, 383, 387, 380, 377, 381, 380, 378, 380, 381, 378, 379, 379, 378) Double = 395.35
同様に,今回のプログラムをコンパイルし実行してみると約6ミリ秒だった.(JITの影響だろうか? 徐々に速くなっている).
scala> time(20, (1 to 10).map(i => ms).reduceRight(_ ++ _).size) List(10, 14, 7, 7, 6, 7, 6, 6, 6, 6, 5, 5, 4, 3, 3, 3, 4, 3, 4, 4) Double = 5.65
次回は,残りのメソッドの実装を目指す.
「Scalaを使ってみる」の目次
- (1) ファイルからの入力
- (2) 英単語の抽出
- (3) 出現回数を数える (mutable版)
- (4) Mapのメソッド
- (5) プログラム作成 (mutable版)
- (6) MultiSetを定義する (mutable版)
- (7) immutable MultiSetを定義する
- (8) immutable MultiSetのメソッドを定義する
- (9) immutable MultiSetのメソッドを高速化する
- (10) immutable MultiSetのメソッドを高速化する (続き)
- (11) immutable MultiSetのメソッドをリファクタリング
- (12) Martin Oderskyによるオンライン授業
- (13) Martin Oderskyによるオンライン授業 (第2回)