Trying Out Scala: (11) Refactoring the Methods of the Immutable MultiSet

I am studying Scala. Although I am still learning, I will use the following exercise to explain how to build a Scala program.

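Count how many times each English word appears in a text file, and display the words in descending order of frequency.
The input text file is Hamlet from Project Gutenberg (saved as hamlet.txt, with line endings converted to LF).
The environment is Scala 2.8.0 RC2 (released May 10, 2010) and Java 1.6.0 on Ubuntu 8.04 LTS.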
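
As a rough illustration of the exercise itself, the sketch below counts words in a short string using only plain Map operations. The object name WordCountSketch, the helper names, and the tokenizing rule (lower-case the text and split on anything that is not an ASCII letter) are assumptions made for this example, not the program developed in this series.

```scala
object WordCountSketch {
  // Assumption: a "word" is a maximal run of ASCII letters, compared case-insensitively.
  def wordCounts(text: String): Map[String, Int] =
    text.toLowerCase.split("[^a-z]+").filter(_.nonEmpty)
      .foldLeft(Map.empty[String, Int]) { (m, w) =>
        m.updated(w, m.getOrElse(w, 0) + 1)
      }

  // Sort by descending count; ties are broken alphabetically so the order is deterministic.
  def byFrequency(counts: Map[String, Int]): List[(String, Int)] =
    counts.toList.sortBy { case (w, c) => (-c, w) }

  val sample = "To be, or not to be"
  val ranked = byFrequency(wordCounts(sample))
  // ranked == List(("be", 2), ("to", 2), ("not", 1), ("or", 1))
}
```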

newMS

Methods such as intersect, union, diff, filter, map, collect, flatMap, and flatten all produce a new multiset.
For example, filter was defined as follows.

  override def filter(p: A => Boolean) = {
    val m = collection.mutable.Map[A,Int]()
    for (x <- keys if p(x))
      m += x -> this(x)
    new MS(m.toMap)
  }

To write these methods in a uniform way, we define the following newMS method.

  private def newMS[B](xvs: TraversableOnce[(B,Int)]) = {
    val m = collection.mutable.Map[B,Int]()
    xvs.foreach {
      case (elem, c) => {
        val c1 = m.getOrElse(elem,0) + c
        if (c1 > 0) m += elem -> c1 else m -= elem
      }
    }
    new MS(m.toMap)
  }

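newMS builds a new multiset of type B from a collection of (B, Int) pairs. If the same key appears more than once in the collection, its values are summed. Negative values are added in as well, and a key is removed whenever its running total drops to zero or below.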
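
The behavior described above can be tried in isolation with the following sketch. mergeCounts is a hypothetical stand-in for newMS that returns a plain Map instead of an MS, so that it runs without the rest of the class; the object name and sample values are likewise assumptions.

```scala
object MergeCountsSketch {
  // Hypothetical stand-in for newMS: same loop, but the result is a plain immutable Map.
  def mergeCounts[B](pairs: Iterator[(B, Int)]): Map[B, Int] = {
    val m = scala.collection.mutable.Map[B, Int]()
    pairs.foreach { case (elem, c) =>
      val c1 = m.getOrElse(elem, 0) + c
      if (c1 > 0) m(elem) = c1 else m -= elem
    }
    m.toMap
  }

  // Duplicate keys are summed: a -> 2 + 3 = 5.
  val summed = mergeCounts(Iterator("a" -> 2, "b" -> 1, "a" -> 3))
  // A total of zero removes the key: b -> 1 - 1 = 0.
  val removed = mergeCounts(Iterator("a" -> 2, "b" -> 1, "b" -> -1))
  // The check runs after every pair, so order matters:
  // a first drops to -1 and is removed, then restarts from 0 and ends at 2.
  val reordered = mergeCounts(Iterator("a" -> -1, "a" -> 2))
}
```

Because the check is applied pair by pair, a negative count that arrives before the matching positive count is discarded rather than carried forward. The methods below always supply the positive counts first.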

The intersect method

intersect can be written as follows.

  def intersect(that: MS[A]) =
    newMS(keys.filter(that.contains(_)).map(x => (x, this(x) min that(x))))
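
To make the rule concrete, here is a small sketch of the same computation on plain Maps of counts: keep only the keys present on both sides, and take the smaller count. The object name and sample values are assumptions for illustration.

```scala
object IntersectSketch {
  val a = Map("x" -> 3, "y" -> 1)
  val b = Map("x" -> 2, "z" -> 5)

  // Keys found in both maps, each paired with the smaller of its two counts.
  val inter = a.keysIterator
    .filter(b.contains(_))
    .map(k => k -> math.min(a(k), b(k)))
    .toMap
  // inter == Map("x" -> 2)
}
```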

The union method

union can be written as follows.

  def union(that: MS[A]) =
    newMS(keys.map(x => (x, this(x))) ++ that.keys.map(x => (x, that(x))))
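
Note that, because newMS sums the values of duplicate keys, this union adds the counts of the two multisets rather than taking the larger one. A minimal sketch on plain Maps, with sumCounts as a hypothetical helper that only does the summing:

```scala
object UnionSketch {
  // Hypothetical helper: sum the counts of repeated keys.
  def sumCounts[K](pairs: Iterator[(K, Int)]): Map[K, Int] =
    pairs.foldLeft(Map.empty[K, Int]) { case (m, (k, c)) =>
      m.updated(k, m.getOrElse(k, 0) + c)
    }

  val a = Map("x" -> 3, "y" -> 1)
  val b = Map("x" -> 2, "z" -> 5)

  // Concatenate both sets of pairs, then let the helper merge the duplicates.
  val union = sumCounts(a.iterator ++ b.iterator)
  // union == Map("x" -> 5, "y" -> 1, "z" -> 5)
}
```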

The diff method

diff can be written as follows.

  def diff(that: MS[A]) =
    newMS(keys.map(x => (x, this(x))) ++ that.keys.map(x => (x, -that(x))))
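
The trick here is that the counts of that are appended with their sign flipped, so the merge step performs the subtraction and drops any key whose total is not positive. The sketch below shows this on plain Maps; merge is a hypothetical immutable variant of the newMS loop, and the sample values are assumptions.

```scala
object DiffSketch {
  // Hypothetical immutable variant of the newMS loop:
  // sum counts per key and drop a key whenever its running total is not positive.
  def merge[K](pairs: Iterator[(K, Int)]): Map[K, Int] =
    pairs.foldLeft(Map.empty[K, Int]) { case (m, (k, c)) =>
      val total = m.getOrElse(k, 0) + c
      if (total > 0) m.updated(k, total) else m - k
    }

  val a = Map("x" -> 3, "y" -> 1)
  val b = Map("x" -> 2, "y" -> 4, "z" -> 5)

  // Positive counts from a come first, followed by the negated counts from b.
  val diff = merge(a.iterator ++ b.iterator.map { case (k, c) => (k, -c) })
  // x: 3 - 2 = 1 is kept; y: 1 - 4 and z: 0 - 5 are not positive, so both are dropped.
  // diff == Map("x" -> 1)
}
```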

The filter method

filter can be written as follows.

  override def filter(p: A => Boolean) =
    newMS(keys.filter(p).map(x => (x, this(x))))

The map method

map can be written as follows.

  def map[B](f: A => B): MS[B] =
    newMS(keys.map(x => (f(x), this(x))))
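
One point worth illustrating is what happens when f sends two different elements to the same result: their counts must be added, which is exactly what the merge step in newMS does. The following sketch reproduces this on a plain Map; the object name and sample data are assumptions.

```scala
object MapSketch {
  val counts = Map("Apple" -> 2, "apple" -> 3, "Pear" -> 1)

  // Apply the function to each key, then sum the counts of keys that now coincide.
  val lowered = counts.iterator
    .map { case (w, c) => (w.toLowerCase, c) }
    .foldLeft(Map.empty[String, Int]) { case (m, (w, c)) =>
      m.updated(w, m.getOrElse(w, 0) + c)
    }
  // "Apple" and "apple" collapse into one key: 2 + 3 = 5.
  // lowered == Map("apple" -> 5, "pear" -> 1)
}
```

Calling map directly on a Scala Map would instead keep only one of the colliding entries, so the explicit summing step matters.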

The collect method

collect can be written as follows.

  def collect[B](pf: PartialFunction[A,B]): MS[B] =
    newMS(keys.filter(pf.isDefinedAt(_)).map(x => (pf(x), this(x))))
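
As a small illustration of the same filter-then-apply pattern, the sketch below uses a partial function that is only defined for strings of digits. The names digits and numbers and the sample data are assumptions for this example.

```scala
object CollectSketch {
  val counts = Map("10" -> 2, "abc" -> 4, "7" -> 1)

  // Defined only for non-empty strings made of digits.
  val digits: PartialFunction[String, Int] = {
    case s if s.nonEmpty && s.forall(_.isDigit) => s.toInt
  }

  // Keep the keys where the partial function is defined, apply it, and carry the count along.
  val numbers = counts.keysIterator
    .filter(digits.isDefinedAt(_))
    .map(k => digits(k) -> counts(k))
    .toMap
  // numbers == Map(10 -> 2, 7 -> 1)
}
```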

The flatMap method

flatMap can be written as follows.

  def flatMap[B](f: A => Traversable[B]): MS[B] =
    newMS(keys.flatMap(x => f(x).toIterator.map(y => (y, this(x)))))
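
In this definition every element y produced from x inherits the count of x, and elements produced from several different keys have their counts added. A sketch on a plain Map, splitting each word into its letters (names and data are assumptions):

```scala
object FlatMapSketch {
  val counts = Map("ab" -> 2, "bc" -> 3)

  // Each letter inherits the count of the word it came from;
  // letters that occur in several words have their counts summed.
  val letters = counts.iterator
    .flatMap { case (w, c) => w.iterator.map(ch => (ch, c)) }
    .foldLeft(Map.empty[Char, Int]) { case (m, (ch, c)) =>
      m.updated(ch, m.getOrElse(ch, 0) + c)
    }
  // 'b' occurs in both words: 2 + 3 = 5.
  // letters == Map('a' -> 2, 'b' -> 5, 'c' -> 3)
}
```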

The flatten method

flatten can be written as follows.

  override def flatten[B](implicit asTraversable: A => Traversable[B]): MS[B] =
    newMS(keys.flatMap(x => asTraversable(x).toIterator.map(y => (y, this(x)))))

The refactored program

The result of the refactoring is as follows.

import scala.collection._
import scala.collection.generic._
import scala.collection.mutable.Builder

case class MS[A] private(val amap: Map[A,Int]) extends Iterable[A]
with GenericTraversableTemplate[A,MS] with IterableLike[A,MS[A]]
with Addable[A,MS[A]] with Subtractable[A,MS[A]]  {
  override def companion = MS
  def iterator = amap.keysIterator.flatMap{
    elem => Iterator.fill[A](amap(elem))(elem)
  }
  def empty = MS.empty[A]
  def contains(elem: A) = amap.contains(elem)
  def apply(elem: A) = amap.getOrElse(elem, 0)
  def andThen[B](g: Int => B) = amap.andThen(g)
  def + (elem: A) = new MS(amap + (elem -> (this(elem)+1)))
  def - (elem: A) = this(elem) match {
    case 0 => this
    case 1 => new MS(amap - elem)
    case c => new MS(amap + (elem -> (c-1)))
  }
  private def newMS[B](xvs: TraversableOnce[(B,Int)]) = {
    val m = collection.mutable.Map[B,Int]()
    xvs.foreach {
      case (elem, c) => {
        val c1 = m.getOrElse(elem,0) + c
        if (c1 > 0) m += elem -> c1 else m -= elem
      }
    }
    new MS(m.toMap)
  }

  def intersect(that: MS[A]) =
    newMS(keys.filter(that.contains(_)).map(x => (x, this(x) min that(x))))
  def & (that: MS[A]) = intersect(that)
  def union(that: MS[A]) =
    newMS(keys.map(x => (x, this(x))) ++ that.keys.map(x => (x, that(x))))
  def | (that: MS[A]) = union(that)
  def diff(that: MS[A]) =
    newMS(keys.map(x => (x, this(x))) ++ that.keys.map(x => (x, -that(x))))
  def &~ (that: MS[A]) = diff(that)
  def subsetOf(that: MS[A]) =
    keys.forall(x => this(x) <= that(x))
  def toMap = amap
  def toSet = amap.keySet
  def keys = amap.keysIterator

  def ++ (that: MS[A]) = union(that)
  def -- (that: MS[A]) = diff(that)

  override def size = keys.map(this(_)).sum
  override def count(p: A => Boolean) =
    keys.filter(p).map(this(_)).sum

  override def find(p: A => Boolean) = keys.find(p)
  override def exists(p: A => Boolean) = keys.exists(p)
  override def forall(p: A => Boolean) = keys.forall(p)

  override def max[B >: A](implicit cmp: Ordering[B]): A =
    keys.max(cmp)
  override def min[B >: A](implicit cmp: Ordering[B]): A =
    keys.min(cmp)

  override def head = amap.keys.head
  override def headOption = amap.keys.headOption
  override def tail = this - head
  override def last = amap.keys.last
  override def lastOption = amap.keys.lastOption
  override def init = this - last

  override def filter(p: A => Boolean) =
    newMS(keys.filter(p).map(x => (x, this(x))))
  override def filterNot(p: A => Boolean) = filter(! p(_))
  def map[B](f: A => B): MS[B] =
    newMS(keys.map(x => (f(x), this(x))))
  def collect[B](pf: PartialFunction[A,B]): MS[B] =
    newMS(keys.filter(pf.isDefinedAt(_)).map(x => (pf(x), this(x))))
  def flatMap[B](f: A => Traversable[B]): MS[B] =
    newMS(keys.flatMap(x => f(x).toIterator.map(y => (y, this(x)))))
  override def flatten[B](implicit asTraversable: A => Traversable[B]): MS[B] =
    newMS(keys.flatMap(x => asTraversable(x).toIterator.map(y => (y, this(x)))))
}

object MS extends TraversableFactory[MS] {
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, MS[A]] =
    new GenericCanBuildFrom[A]
  def newBuilder[A] = new Builder[A,MS[A]] {
    private var amap = collection.mutable.Map[A,Int]()
    def clear = amap.clear
    def += (elem: A) = {
      amap += (elem -> (amap.getOrElse(elem,0)+1))
      this
    }
    def result = new MS(amap.toMap)
  }
}

For now, let us call this finished.