Scala Puzzlers, part 2

Jul 24 2008, 14:51 MSD |  [  Scala  ]

This puzzler is very strange. I have not yet completely figured out what causes the performance problem described below, but my guess is that it is a bug in the JVM, in the Scala compiler, or in both.

I tried my hand at the Google Code Jam Qualification Round this year, and there was a nice problem called Fly Swatter (unfortunately there's no direct link to the problem description, but you can find it by clicking the “Qualification Round” link on the GCJ Contest page). It's fairly simple, and it didn't take long to write a working solution in Scala. My solution produced correct results for the small test dataset. Then I downloaded the large dataset (once it is downloaded, one has 8 minutes to submit the results), and…

Performance problem puzzler

… I found out that my program was horribly slow. I didn't expect it to run that slowly: in the 8 minutes I had, fewer than 30 (of 99) results were produced. I got the full results only after 32 minutes, alas, too late. I knew that in the worst case I'd get something like 250000 iterations for each test case, where each iteration computes a couple of arcsines and several square roots, but hey, computers are fast these days, and I knew I hadn't screwed up the algorithm (there were some obvious ways to improve it, but I figured they weren't necessary: the performance should have been fine without them).
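As a rough back-of-the-envelope check of that expectation, here is a minimal sketch that runs the worst-case number of iterations for all 99 cases. The `WorkEstimate` object and the body of `iteration` are hypothetical stand-ins (a couple of arcsines and a few square roots), not the actual Fly Swatter computation, and `math.sqrt`/`math.asin` assume a current Scala version rather than the 2008-era `Math` object.

```scala
object WorkEstimate {
  val testCases = 99
  val worstCaseIterations = 250000

  // Upper bound on the number of iterations for the whole large dataset.
  val totalIterations: Long = testCases.toLong * worstCaseIterations

  // Hypothetical stand-in for one iteration: two arcsines and three square roots.
  def iteration(x: Double): Double =
    math.asin(x) + math.asin(x / 2) + math.sqrt(x) + math.sqrt(1 - x) + math.sqrt(1 + x)

  def main(args: Array[String]): Unit = {
    val start = System.nanoTime()
    var acc = 0.0
    var k = 0L
    while (k < totalIterations) {
      // Keep the argument inside [0, 1) so asin and sqrt stay defined.
      acc += iteration((k % 1000) / 1000.0)
      k += 1
    }
    val seconds = (System.nanoTime() - start) / 1e9
    println(f"$totalIterations iterations, checksum $acc%.3f, $seconds%.2f s")
  }
}
```

The total comes to 24,750,000 iterations. How long the loop takes depends on the machine and the JVM, so no timing is claimed here.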

Now, here's the deal: my slow Scala program follows at the bottom of the post (you can also download it). Try to figure out what's wrong with it and how you can improve its speed by two orders of magnitude (I'm not kidding: it now takes 8.5 seconds to run on the same laptop). One caveat though: run the compiled program with the Java HotSpot Server VM (it is the default on my laptop, and only the Server VM shows the problem, not the Client VM), e.g.:

% time ( cat C-large.in | java -server -cp .:$SCALA_HOME/lib/scala-library.jar FlySwatter )

Input data should be piped to standard input; use the C-large.in dataset to test the program.

Spoiler

If you're Martin Odersky (and you want to check whether it is a bug in the Scala compiler), or if you're just too lazy to spend your time inspecting the code, profiling it, getting puzzled because the profiler doesn't give you any meaningful results, and then giving up, take a look at the fixed program (there's a comment at the top which describes what has been done). Prepare to be surprised!

Slow version

/*
 * This is a slow version of the FlySwatter program.
 * It performs poorly on Server VM.
 * Try to find out why!
 */
import java.io._

object FlySwatter extends Application {
  case class Square(x1: Double, y1: Double, side: Double) {
    val x2 = x1 + side
    val y2 = y1 + side

    def area = side*side

    def isValid = side > 0

    def contractBy(f: Double): Square = {
      if (2*f >= side) {
        Square(x1 + side/2, y1 + side/2, 0)
      } else {
        Square(x1 + f, y1 + f, side - 2*f)
      }
    }
  }

  case class Case(f: Double, rr: Double, t: Double, r: Double, g: Double)

  def parseData: Seq[Case] = {
    val in = new BufferedReader(new InputStreamReader(System.in))
    def readCase: Case = {
      def readWords: Seq[String] = in.readLine.split(' ').filter(_.length > 0)

      val caseData = readWords.map(_.toDouble)
      Case(caseData(0), caseData(1), caseData(2), caseData(3), caseData(4))
    }

    val numCases = in.readLine.toInt
    val cases = for (i <- (1 to numCases).force) yield readCase
    cases
  }

  def calculateProbability(c: Case): Double = {
    def pointInCircle(x: Double, y: Double, r: Double): Boolean = {
      x*x + y*y < r*r
    }
    def squareAndCircleIntersectionArea(s: Square, r: Double): Double = {
      // calculates the definite integral from a to b of sqrt(r^2-x^2) dx
      def integrateCirclePart(a: Double, b: Double, r: Double): Double = {
        def i(x: Double): Double = x*Math.sqrt(r*r - x*x)/2 + r*r/2 * Math.asin(x/r)
        i(b)-i(a)
      }
      assert(s.x1 > 0 && s.y1 > 0)
      assert(r > 0)

      if (s.isValid) {
        val lowerLeftInCircle = pointInCircle(s.x1, s.y1, r)    // using lazy vals for these four booleans
        val upperLeftInCircle = pointInCircle(s.x1, s.y2, r)    // makes the program run ~25% slower,
        val lowerRightInCircle = pointInCircle(s.x2, s.y1, r)   // however real performance hit is probably lower
        val upperRightInCircle = pointInCircle(s.x2, s.y2, r)   // because the measurement precision is quite low

        val area: Double = (lowerLeftInCircle, upperLeftInCircle,
                    lowerRightInCircle, upperRightInCircle) match {
          case (false, _, _, _) => 0                                // square is outside of the circle
          case (true, _, _, true) => s.area                         // whole square is in the circle
          case (true, false, false, false) =>                       // single corner is in the circle
            val x = Math.sqrt(r*r - s.y1*s.y1)
            integrateCirclePart(s.x1, x, r) - (x-s.x1)*s.y1
          case (true, true, false, _) =>                            // left side of the square is in the circle
            integrateCirclePart(s.y1, s.y2, r) - (s.y2-s.y1)*s.x1
          case (true, false, true, _) =>                            // bottom side of the square is in the circle
            integrateCirclePart(s.x1, s.x2, r) - (s.x2-s.x1)*s.y1
          case (true, true, true, false) =>                         // 3 corners of the square are in the circle
            val x = Math.sqrt(r*r - s.y2*s.y2)
            (x-s.x1)*(s.y2-s.y1) + integrateCirclePart(x, s.x2, r) - (s.x2-x)*s.y1
        }
        assume(area >= 0)
        area
      } else {
        0 // empty square
      }
    }

    assert(c.g + 2*c.r > 0)
    val maxI = Math.ceil((c.rr - c.t - c.r)/(c.g + 2*c.r)).toInt
    val squares = for {
      i <- 0 until maxI; j <- 0 until maxI // until method creates a lazy list, thus we get a lazy list of Squares in the end
      x1 = c.r + i*(c.g + 2*c.r)
      y1 = c.r + j*(c.g + 2*c.r)
    } yield Square(x1, y1, c.g).contractBy(c.f)

    val squaresAndCircleIntersectionArea =
      (0.0 /: squares.map(squareAndCircleIntersectionArea(_, c.rr - c.t - c.f)))((a: Double, b: Double) => a+b)

    val flyIsSafeProb = squaresAndCircleIntersectionArea / (Math.Pi * c.rr * c.rr / 4)
    1-flyIsSafeProb
  }

  val cases = parseData
  for (i <- 0 until cases.length) {

    val prob = calculateProbability(cases(i))
    System.out.printf("Case #%1$d: %2$.6f\n", Array[Object](new java.lang.Integer(i+1), new java.lang.Double(prob)))
  }
}
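The closed form used by `integrateCirclePart` in the listing above can be sanity-checked in isolation. The sketch below is not part of the original program: `CircleIntegralCheck`, `antiderivative` and `midpointRule` are hypothetical names, and it uses `math.sqrt`/`math.asin` from a current Scala version instead of the 2008-era `Math` object. It compares the closed form against a simple numerical approximation and against the area of a quarter of the unit circle.

```scala
object CircleIntegralCheck {
  // Antiderivative of sqrt(r^2 - x^2); the same closed form as `i` in the listing.
  def antiderivative(x: Double, r: Double): Double =
    x * math.sqrt(r * r - x * x) / 2 + r * r / 2 * math.asin(x / r)

  // Definite integral from a to b of sqrt(r^2 - x^2) dx, via the closed form.
  def integrateCirclePart(a: Double, b: Double, r: Double): Double =
    antiderivative(b, r) - antiderivative(a, r)

  // Midpoint-rule approximation of the same integral, used only as a cross-check.
  def midpointRule(a: Double, b: Double, r: Double, steps: Int): Double = {
    val h = (b - a) / steps
    var sum = 0.0
    var k = 0
    while (k < steps) {
      val x = a + (k + 0.5) * h
      sum += math.sqrt(r * r - x * x)
      k += 1
    }
    sum * h
  }

  def main(args: Array[String]): Unit = {
    // Integrating from 0 to r gives the area of a quarter circle, pi * r^2 / 4.
    println(s"closed form: ${integrateCirclePart(0, 1, 1)}")
    println(s"pi / 4:      ${math.Pi / 4}")
    println(s"midpoint:    ${midpointRule(0, 1, 1, 100000)}")
  }
}
```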

Enjoy!

Comments:

If you add -XX:+PrintCompilation to your launcher, you will notice that no methods are compiled by the HotSpot JIT for the broken version. This means that it runs purely in the interpreter and as such it's horribly slow. The fixed version does not have the same issue.

My theory is that code that runs in the class initialiser cannot be JITed, or something similar, and that this is the source of the problem.

Posted by Ismael Juma on July 24, 2008 at 04:57 PM MSD #

In your second version, every variable defined in "main" is obviously a function-local variable. But I wonder if, in the first version, the Scala compiler turns a couple of those into fields (and maybe tries to ensure safety with "synchronized" or something).

Have you tried making the following change to your second program:
- rename function "main" to "main2"
- do "FlySwatter extends Application"
- at the bottom, call "main2()"
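The two layouts this experiment contrasts can be sketched roughly as follows. This is a minimal illustration and not the FlySwatter program: `FieldsInBody` and `LocalsInMain2` are hypothetical names, the square-root sum is a stand-in workload, and a plain `object` body is used in place of `extends Application` (an assumption made so the sketch compiles on current Scala versions; in both cases the body runs while the object is being initialised).

```scala
// Variant A: intermediate values are written directly in the object body,
// so each `val` becomes a field of the object.
object FieldsInBody {
  val n = 1000
  val roots = (1 to n).map(i => math.sqrt(i.toDouble))
  val total = roots.sum
}

// Variant B: the same work is moved into a method (the suggested `main2`),
// so `n` and `roots` are function-local variables.
object LocalsInMain2 {
  def main2(): Double = {
    val n = 1000
    val roots = (1 to n).map(i => math.sqrt(i.toDouble))
    roots.sum
  }

  // The call still happens in the object body, i.e. during initialisation.
  val total = main2()
}

object InitializerExperiment {
  def main(args: Array[String]): Unit = {
    println(FieldsInBody.total)
    println(LocalsInMain2.total)
  }
}
```

Both variants compute the same value. If only Variant A were slow in a real measurement, that would point at fields versus locals; if both were slow, it would point at running during initialisation. The sketch itself makes no claim about which outcome occurs.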

I'm currently not on my own machine, so I don't have a Scala environment handy to try this out.

Try posting this to the Scala mailing list.

Posted by Kannan Goundan on July 24, 2008 at 06:04 PM MSD #

Ismael, that is true, but there is something else going on. Here are the timings for various options:

* without any options: 845.92
* -Xint: 440.11 (!)
* -Xbatch: 807.78

Even if it is being run interpreted only, there's something else going on, otherwise it wouldn't be twice as fast in the interpreted mode.

Posted by Ivan Tarasov on July 24, 2008 at 06:08 PM MSD #

I was going to reply to the comment above, but I then noticed that Ivan posted something similar to the Scala mailing list, so I posted the reply there. Probably easier to continue the discussion in one place. It can easily be followed through Gmane (posts can also be done that way without subscribing). The current thread:

http://thread.gmane.org/gmane.comp.lang.scala/12795

The initial thread:

http://thread.gmane.org/gmane.comp.lang.scala/12790

Ismael

Posted by Ismael Juma on July 24, 2008 at 07:26 PM MSD #

Hello Ivan,

I don't have an exotic puzzle like yours yet, but I did come across this:

object Demo extends Application {
  def exitWith(s: String) = println(s); exit(1)

  def main(args: Array[String]) {
    println("demo")
  }
}

It compiles, but I see no output. It took me some time to realize what a stupid mistake I had made by putting multiple statements on one line while trying to save parentheses and lines. So yeah, some conventions go a long way toward less frustration, I guess. :)
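The pitfall in the snippet above is that the `;` ends the method definition, so the statement after it belongs to the object body and runs when the object is initialised. Here is a minimal sketch of the same parse that records calls in a buffer instead of exiting the JVM. The `SemicolonPitfall` object and its `log` buffer are hypothetical names used only for illustration.

```scala
import scala.collection.mutable.ListBuffer

object SemicolonPitfall {
  val log = ListBuffer.empty[String]

  // Parsed as two statements: the method body is only `log += s`, and
  // `log += "stray"` is a statement of the object body that runs once,
  // when the object is initialised.
  def unbraced(s: String) = log += s; log += "stray"

  // With braces, both statements belong to the method body.
  def braced(s: String) = { log += s; log += "inside" }

  def main(args: Array[String]): Unit = {
    // "stray" is already in the log before any method has been called.
    println(log.toList)
    unbraced("a")
    braced("b")
    println(log.toList)
  }
}
```

In the original snippet the stray statement is `exit(1)`, so the program terminates during initialisation, before `main` gets a chance to print anything.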

Have fun!

Posted by Zemian Deng on July 24, 2008 at 09:34 PM MSD #

Actually, you can remove the "extends Application" part, as I forgot to "override" the main.

Posted by Zemian Deng on July 24, 2008 at 09:36 PM MSD #
