Sunday, April 12, 2015

Using UnitTests to solve complex problems



Using UnitTests to solve complex problems
by  Maksim Kozyarchuk




The Problem

   I was asked to solve a problem the other day. I needed to produce an answer within 5-10 minutes and found myself outmatched by snowballing complexity.   Here is the problem.

   Suppose you are able to equip cars traveling on a highway with a traffic sensor that can report when a car enters and clears traffic.   These sensors are very simple: they can only send two signals, enter_traffic and clear_traffic.  Along with the signal, the car will report the mile marker where the signal was transmitted and a unique car identifier.  Furthermore, there are on and off ramps onto the highway, and as a rule a car exiting the highway will report clear_traffic if it was in traffic.   I was asked to come up with an algorithm to report up-to-date traffic information for a given road.   The algorithm had to be scalable and capable of supporting roads of arbitrary length.

   The word "scalable" stuck in my mind, and for the next few minutes I attempted to find the right data structure capable of representing the data.   The realization that cars could enter and exit the highway in the middle of the traffic piled on complexity.  Additional complexity is introduced by the fact that traffic information ages and becomes explicit events.  Within minutes I felt outmatched and started seeking a different approach.


Starting with UnitTests
   A more helpful approach was to explore the algorithm through simple unit tests: starting with a single car, then moving to multiple cars, then introducing overlapping and disjoint intervals, handling of entrances and exits, and finally the notion of a ticking clock that removes old traffic jams and cleans up unused data for more efficient access.  What came out surprised me.  The full algorithm was implemented in about 1 hour in under 60 lines of Python code, with the help of 16 simple unit tests.


class Road(object):
    """Static description of a road: the mile markers of its ramps.

    Attributes:
        exits: mile markers at which cars can leave the road.
        entrances: mile markers at which cars can join the road.
    """

    def __init__(self, exits=None, entrances=None):
        """Store the ramp locations.

        Args:
            exits: collection of exit mile markers; defaults to no exits.
            entrances: collection of entrance mile markers; defaults to
                no entrances.
        """
        # Use None as the default instead of a list literal: a literal
        # default is created once and would be shared (and mutated) by
        # every Road built without explicit arguments.
        self.exits = [] if exits is None else exits
        self.entrances = [] if entrances is None else entrances

class Reading(object):
    """One stretch of congested road, built from car sensor signals.

    Attributes:
        start: mile marker where the stretch begins.
        end: mile marker where the stretch ends.
        age: clock value recorded when the reading was created.
        start_car_id: ids of cars that reported entering this stretch;
            seeded with the car that created the reading.
        end_car_id: ids of cars that reported clearing this stretch;
            empty on creation.
    """

    def __init__(self, car_id, start, end, age):
        """Create a reading opened by ``car_id`` spanning ``start``..``end``."""
        self.start, self.end = start, end
        self.age = age
        self.start_car_id = {car_id}
        self.end_car_id = set()
       
class Traffic(object):
    """Maintains the current list of congested stretches on one road.

    Cars report ``enter_traffic`` / ``clear_traffic`` signals; each
    congested stretch is kept as a reading object with ``start``, ``end``,
    ``age``, ``start_car_id`` and ``end_car_id`` attributes.  A logical
    clock (``current_age``) is advanced with ``increment_age`` and is used
    to expire readings that have not been refreshed recently.
    """

    # A reading survives increment_age() only while
    # reading.age > current_age - MAX_AGE.
    MAX_AGE = 10

    def __init__(self, road):
        """Start with no readings and the clock at 1.

        Args:
            road: object exposing ``exits`` and ``entrances`` collections
                of mile markers.
        """
        self.road = road
        self.readings = []
        self.current_age = 1

    def increment_age(self):
        """Advance the clock one tick, expire stale readings and compact ids.

        Readings whose age is not greater than ``current_age - MAX_AGE``
        are dropped.  For the survivors, every car that has already
        cleared the stretch is removed from ``start_car_id`` and
        ``end_car_id`` is reset, so only cars still in traffic are kept.
        """
        self.current_age += 1
        cutoff = self.current_age - self.MAX_AGE
        self.readings = [r for r in self.readings if r.age > cutoff]
        for r in self.readings:
            r.start_car_id = r.start_car_id - r.end_car_id
            r.end_car_id = set()

    def get_traffic(self):
        """Return the congested stretches as a sorted list of (start, end)."""
        return sorted((r.start, r.end) for r in self.readings)

    def enter_traffic(self, car_id, mile):
        """Record that ``car_id`` hit traffic at ``mile``.

        If ``mile`` falls inside an existing reading, the car joins the
        first such reading and refreshes its age.  The reading's start is
        moved to ``mile`` unless ``mile`` is an entrance (a car joining
        from a ramp says nothing about where the jam begins).  Otherwise
        a new one-mile reading ``(mile, mile + 1)`` is opened.
        """
        for r in self.readings:
            if r.start <= mile <= r.end:
                if mile not in self.road.entrances:
                    r.start = mile
                r.start_car_id.add(car_id)
                r.age = self.current_age
                return
        self.readings.append(
            Reading(car_id, mile, mile + 1, self.current_age))

    def clear_traffic(self, car_id, mile):
        """Record that ``car_id`` cleared traffic at ``mile``.

        Every reading the car entered and has not yet cleared gets its
        end updated and its age refreshed.  At an exit the end is only
        ever extended (``max``), because a car leaving via a ramp does
        not prove the jam ends there; anywhere else the end is set to
        ``mile``.  Afterwards, any other reading whose span contains the
        new end of the last updated reading is discarded as superseded.
        Signals from cars with no open reading are ignored.
        """
        updated = None
        for r in self.readings:
            if car_id in r.start_car_id and car_id not in r.end_car_id:
                if mile in self.road.exits:
                    r.end = max(mile, r.end)
                else:
                    r.end = mile
                r.end_car_id.add(car_id)
                r.age = self.current_age
                updated = r
        if updated is not None:
            # Exclude `updated` by identity rather than relying on the
            # range test: if the clear signal arrived before the
            # reading's start (end < start) the range test is false and
            # the reading would otherwise be listed twice.
            self.readings = [updated] + [
                r for r in self.readings
                if r is not updated
                and not (r.start <= updated.end <= r.end)
            ]
     


UnitTests

   Including UnitTests for completeness.



class TestTraffic(unittest.TestCase):
    """Behaviour of Traffic on a road with no exits or entrances."""

    def setUp(self):
        r = Road()
        self.t = Traffic(road=r)

    def testNoCars(self):
        self.assertEqual([], self.t.get_traffic())

    def testOneCarHitTraffic(self):
        # A lone enter signal opens a one-mile reading.
        self.t.enter_traffic(car_id=1, mile=1)
        self.assertEqual([(1, 2)], self.t.get_traffic())

    def testOneCarHitAndClearsTraffic(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.assertEqual([(1, 3)], self.t.get_traffic())

    # Renamed: this test previously shared its name with the next one and
    # was therefore silently replaced by it and never executed.
    def testSecondCarEntersWithOverlappingSignalOverridesStart(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=2)
        self.assertEqual([(2, 3)], self.t.get_traffic())

    def testSecondCarEntersWithOverlappingSignalsOverrides(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=2)
        self.t.clear_traffic(car_id=2, mile=4)
        self.assertEqual([(2, 4)], self.t.get_traffic())

    # Renamed for the same reason as above.  The expected value was also
    # stale, (4, 10): an enter signal outside any existing reading opens
    # a one-mile reading (see testOneCarHitTraffic), i.e. (4, 5).
    def testSecondCarEntersWithNonOverlappingSignalAddsReading(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=4)
        self.assertEqual([(1, 3), (4, 5)], self.t.get_traffic())

    def testSecondCarEntersWithNonOverlappingSignalsAdds(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=4)
        self.t.clear_traffic(car_id=2, mile=6)
        self.assertEqual([(1, 3), (4, 6)], self.t.get_traffic())

    def testThirdCarEntersWithOverlappingSignals(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=4)
        self.t.clear_traffic(car_id=2, mile=6)
        self.t.enter_traffic(car_id=3, mile=2)
        self.assertEqual([(2, 3), (4, 6)], self.t.get_traffic())

    def testThirdCarClearsWithOverlappingSignals(self):
        # Car 3 clearing at mile 5 lands inside the second reading, so
        # the two readings collapse into one.
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=2, mile=4)
        self.t.clear_traffic(car_id=2, mile=6)
        self.t.enter_traffic(car_id=3, mile=2)
        self.t.clear_traffic(car_id=3, mile=5)
        self.assertEqual([(2, 5)], self.t.get_traffic())

    def testSameCarEntersAndClearsMultipletimes(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.t.enter_traffic(car_id=1, mile=4)
        self.t.clear_traffic(car_id=1, mile=6)
        self.assertEqual([(1, 3), (4, 6)], self.t.get_traffic())

    def testStopAndGoTrafficAroundSameArea(self):
        self.t.enter_traffic(car_id=1, mile=3)
        self.t.clear_traffic(car_id=1, mile=5)
        self.t.enter_traffic(car_id=2, mile=4)
        self.t.enter_traffic(car_id=3, mile=2)
        self.assertEqual([(2, 3), (4, 5)], self.t.get_traffic())
        self.t.clear_traffic(car_id=2, mile=6)
        self.t.clear_traffic(car_id=3, mile=5)
        self.assertEqual([(2, 5)], self.t.get_traffic())

    def testTraficAgesOff(self):
        # Shadow the class-level MAX_AGE on this instance only.
        self.t.MAX_AGE = 3
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=3)
        self.assertEqual([(1, 3)], self.t.get_traffic())
        self.t.increment_age()
        self.assertEqual([(1, 3)], self.t.get_traffic())
        self.t.increment_age()
        self.assertEqual([(1, 3)], self.t.get_traffic())
        self.t.increment_age()
        self.assertEqual([], self.t.get_traffic())

    def testIncrementAgeRemovesPairedCars(self):
        # Cars 1-3 both entered and cleared; only 4 and 5 are still in
        # traffic once the clock ticks.
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=6)
        self.t.enter_traffic(car_id=2, mile=2)
        self.t.clear_traffic(car_id=2, mile=7)
        self.t.enter_traffic(car_id=3, mile=2)
        self.t.clear_traffic(car_id=3, mile=7)
        self.t.enter_traffic(car_id=4, mile=2)
        self.t.enter_traffic(car_id=5, mile=2)
        self.t.increment_age()
        self.assertEqual(set([4, 5]), self.t.readings[0].start_car_id)
        self.assertEqual(set(), self.t.readings[0].end_car_id)
       
class TestTrafficWithExits(unittest.TestCase):
    """Behaviour of Traffic when signals are sent at exits and entrances."""

    def setUp(self):
        r = Road(exits=[10, 20, 30, 40, 60], entrances=[5, 15, 35, 45, 55])
        self.t = Traffic(road=r)

    def testEnterAtEntranceDoesNotOverrideStart(self):
        # Mile 5 is an entrance, so the reading keeps its start at 1.
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=13)
        self.t.enter_traffic(car_id=2, mile=5)
        self.assertEqual([(1, 13)], self.t.get_traffic())

    def testEnterAtEntranceAndClearsTrafficFirst(self):
        self.t.enter_traffic(car_id=1, mile=4)
        self.t.enter_traffic(car_id=2, mile=5)
        self.t.clear_traffic(car_id=2, mile=13)
        self.assertEqual([(4, 13)], self.t.get_traffic())
        self.t.clear_traffic(car_id=1, mile=12)
        self.assertEqual([(4, 12)], self.t.get_traffic())

    def testEnterAtEntranceAndClearsTrafficFirstWithDistance(self):
        # Two separate readings until car 1 clears inside the second one,
        # at which point they merge.
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.enter_traffic(car_id=2, mile=5)
        self.t.clear_traffic(car_id=2, mile=13)
        self.assertEqual([(1, 2), (5, 13)], self.t.get_traffic())
        self.t.clear_traffic(car_id=1, mile=12)
        self.assertEqual([(1, 12)], self.t.get_traffic())

    def testClearAtExitIncreasesTraffic(self):
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=10)
        self.assertEqual([(1, 10)], self.t.get_traffic())

    def testClearAtExitDoesNotDecreeseTraffic(self):
        # Mile 10 is an exit: clearing there must not shrink the end
        # already established at mile 12.
        self.t.enter_traffic(car_id=1, mile=1)
        self.t.clear_traffic(car_id=1, mile=12)
        self.t.enter_traffic(car_id=2, mile=1)
        self.t.clear_traffic(car_id=2, mile=10)
        self.assertEqual([(1, 12)], self.t.get_traffic())





No comments: