Scala Puzzlers, part 2
This puzzler is a strange one. I have not yet figured out exactly what causes the performance problem described below, but my guess is that it is a bug in the JVM, in the Scala compiler, or in both.
I tried my hand at the Google Code Jam Qualification Round this year, and there was a nice problem called Fly Swatter (unfortunately there's no direct link to the problem description, but you can find it by clicking the “Qualification Round” link on the GCJ Contest page). It's fairly simple, and it didn't take long to write a working solution in Scala. My solution produced correct results for the small test dataset, so I downloaded the large dataset (once it is downloaded, you have 8 minutes to submit the results), and…
Performance problem puzzler
… I found that my program was horribly slow. I didn't expect it to run that slowly: in the 8 minutes I had, it produced fewer than 30 of the 99 results. I only got the full results after 32 minutes, alas, too late. I knew that in the worst case I'd get something like 250000 iterations per test case, each computing a couple of arcsines and several square roots, but hey, computers are fast these days, and I knew I hadn't screwed up the algorithm (there were some obvious ways to improve it, but I figured they weren't necessary: the performance should have been fine without them).
Now, here's the deal: my slow Scala program follows at the bottom of the post (you can also download it). Try to figure out what's wrong with it and how you can improve its speed by two orders of magnitude (I'm not kidding: it now takes 8.5 seconds to run on the same laptop). One caveat, though: run the compiled program using the Java HotSpot Server VM (it is the default on my laptop, and only the Server VM shows the problem, not the Client VM), e.g.:
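To put that worst case in perspective, here is a rough sketch (not part of the original solution) that simply performs that many iterations. The object name WorkEstimate, the iteration helper, and the exact mix of two arcsines and three square roots per iteration are assumptions made for illustration. It is written against a current Scala standard library (math.asin, math.sqrt) rather than the older API used in the listing at the bottom of the post.

```scala
// Rough cost estimate only. The per-iteration work (2 asin + 3 sqrt)
// is an assumed stand-in for the real computation, not a copy of it.
object WorkEstimate {
  val cases = 99                 // number of test cases in the large dataset
  val iterationsPerCase = 250000 // worst-case iterations per test case

  // Hypothetical stand-in for one iteration; x must lie in (0, 1).
  def iteration(x: Double): Double =
    math.asin(x) + math.asin(x / 2) +
      math.sqrt(x) + math.sqrt(1 - x) + math.sqrt(1 - x * x)

  // Runs every iteration and returns (iteration count, checksum).
  // The checksum keeps the JIT from discarding the arithmetic.
  def run(): (Long, Double) = {
    var count = 0L
    var sink = 0.0
    for (_ <- 0 until cases) {
      var i = 0
      while (i < iterationsPerCase) {
        sink += iteration((i + 1).toDouble / (iterationsPerCase + 1))
        count += 1
        i += 1
      }
    }
    (count, sink)
  }

  def main(args: Array[String]): Unit = {
    val start = System.nanoTime()
    val (count, sink) = run()
    val elapsedMs = (System.nanoTime() - start) / 1000000
    println(s"$count iterations in $elapsedMs ms (checksum $sink)")
  }
}
```

That is 99 × 250000 = 24750000 iterations in total. No particular timing is claimed here; the point is only that the amount of arithmetic is modest, so it is easy to measure on your own machine.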
% time ( cat C-large.in | java -server -cp .:$SCALA_HOME/lib/scala-library.jar FlySwatter )
The input data should be piped to standard input; use the C-large.in dataset to test the program.
Spoiler
If you're Martin Odersky (and you want to check whether it is a bug in the Scala compiler), or if you're just too lazy to inspect the code, profile it, get puzzled because the profiler doesn't give you any meaningful results, and then give up, take a look at the fixed program (a comment at the top describes what has been changed). Prepare to be surprised!
Slow version
/*
 * This is a slow version of the FlySwatter program.
 * It performs poorly on Server VM.
 * Try to find out why!
 */

import java.io._

object FlySwatter extends Application {

  case class Square(x1: Double, y1: Double, side: Double) {
    val x2 = x1 + side
    val y2 = y1 + side
    def area = side*side
    def isValid = side > 0
    def contractBy(f: Double): Square = {
      if (2*f >= side) {
        Square(x1 + side/2, y1 + side/2, 0)
      } else {
        Square(x1 + f, y1 + f, side - 2*f)
      }
    }
  }

  case class Case(f: Double, rr: Double, t: Double, r: Double, g: Double)

  def parseData: Seq[Case] = {
    val in = new BufferedReader(new InputStreamReader(System.in))

    def readCase: Case = {
      def readWords: Seq[String] = in.readLine.split(' ').filter(_.length > 0)
      val caseData = readWords.map(_.toDouble)
      Case(caseData(0), caseData(1), caseData(2), caseData(3), caseData(4))
    }

    val numCases = in.readLine.toInt
    val cases = for (i <- (1 to numCases).force) yield readCase
    cases
  }

  def calculateProbability(c: Case): Double = {

    def pointInCircle(x: Double, y: Double, r: Double): Boolean = {
      x*x + y*y < r*r
    }

    def squareAndCircleIntersectionArea(s: Square, r: Double): Double = {

      // calculates the definite integral from a to b of sqrt(r^2-x^2) dx
      def integrateCirclePart(a: Double, b: Double, r: Double): Double = {
        def i(x: Double): Double = x*Math.sqrt(r*r - x*x)/2 + r*r/2 * Math.asin(x/r)
        i(b)-i(a)
      }

      assert(s.x1 > 0 && s.y1 > 0)
      assert(r > 0)

      if (s.isValid) {
        val lowerLeftInCircle = pointInCircle(s.x1, s.y1, r)  // using lazy vals for these four booleans
        val upperLeftInCircle = pointInCircle(s.x1, s.y2, r)  // makes the program run ~25% slower,
        val lowerRightInCircle = pointInCircle(s.x2, s.y1, r) // however real performance hit is probably lower
        val upperRightInCircle = pointInCircle(s.x2, s.y2, r) // because the measurement precision is quite low

        val area: Double = (lowerLeftInCircle, upperLeftInCircle, lowerRightInCircle, upperRightInCircle) match {
          case (false, _, _, _) => 0 // square is outside of the circle
          case (true, _, _, true) => s.area // whole square is in the circle
          case (true, false, false, false) => // single corner is in the circle
            val x = Math.sqrt(r*r - s.y1*s.y1)
            integrateCirclePart(s.x1, x, r) - (x-s.x1)*s.y1
          case (true, true, false, _) => // left side of the square is in the circle
            integrateCirclePart(s.y1, s.y2, r) - (s.y2-s.y1)*s.x1
          case (true, false, true, _) => // bottom side of the square is in the circle
            integrateCirclePart(s.x1, s.x2, r) - (s.x2-s.x1)*s.y1
          case (true, true, true, false) => // 3 corners of the square are in the circle
            val x = Math.sqrt(r*r - s.y2*s.y2)
            (x-s.x1)*(s.y2-s.y1) + integrateCirclePart(x, s.x2, r) - (s.x2-x)*s.y1
        }
        assume(area >= 0)
        area
      } else {
        0 // empty square
      }
    }

    assert(c.g + 2*c.r > 0)
    val maxI = Math.ceil((c.rr - c.t - c.r)/(c.g + 2*c.r)).toInt
    val squares = for {
      i <- 0 until maxI; j <- 0 until maxI // until method creates a lazy list, thus we get a lazy list of Squares in the end
      x1 = c.r + i*(c.g + 2*c.r)
      y1 = c.r + j*(c.g + 2*c.r)
    } yield Square(x1, y1, c.g).contractBy(c.f)

    val squaresAndCircleIntersectionArea =
      (0.0 /: squares.map(squareAndCircleIntersectionArea(_, c.rr - c.t - c.f)))((a: Double, b: Double) => a+b)
    val flyIsSafeProb = squaresAndCircleIntersectionArea / (Math.Pi * c.rr * c.rr / 4)
    1-flyIsSafeProb
  }

  val cases = parseData
  for (i <- 0 until cases.length) {
    val prob = calculateProbability(cases(i))
    System.out.printf("Case #%1$d: %2$.6f\n", Array[Object](new java.lang.Integer(i+1), new java.lang.Double(prob)))
  }
}
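If you want to convince yourself that the closed-form integral in integrateCirclePart is not the culprit, a small standalone check like the following may help. It is only a sketch: the object name IntegralCheck and the midpoint helper are made up for this illustration, and it uses math.sqrt and math.asin from a current Scala standard library instead of the older Math object used in the listing. The formula itself is copied from the slow version unchanged.

```scala
// Standalone sanity check for the antiderivative of sqrt(r^2 - x^2).
object IntegralCheck {

  // Same closed form as integrateCirclePart in the slow version:
  // F(x) = x*sqrt(r^2 - x^2)/2 + r^2/2 * asin(x/r)
  def integrateCirclePart(a: Double, b: Double, r: Double): Double = {
    def i(x: Double): Double =
      x * math.sqrt(r * r - x * x) / 2 + r * r / 2 * math.asin(x / r)
    i(b) - i(a)
  }

  // Hypothetical helper: midpoint-rule approximation of the same
  // integral with n slices, used only as an independent cross-check.
  def midpoint(a: Double, b: Double, r: Double, n: Int): Double = {
    val h = (b - a) / n
    (0 until n).map { k =>
      val x = a + (k + 0.5) * h
      math.sqrt(r * r - x * x)
    }.sum * h
  }

  def main(args: Array[String]): Unit = {
    // A full quarter of the unit circle should have area Pi/4.
    val quarter = integrateCirclePart(0, 1, 1)
    println(math.abs(quarter - math.Pi / 4) < 1e-12)

    // An arbitrary slice should agree with the numeric approximation.
    val exact = integrateCirclePart(0.2, 0.7, 1)
    val approx = midpoint(0.2, 0.7, 1, 10000)
    println(math.abs(exact - approx) < 1e-6)
  }
}
```

Both comparisons only confirm that the formula is correct; they say nothing about why the full program is slow on the Server VM.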
Enjoy!