Tags: c++, design-patterns, switch-statement, state-machine, idioms

C++ code for state machine


This was an interview question to be coded in C++:

Write code for a vending machine. Start with a simple one that vends just one type of item, so two state variables, money and inventory, would do.

My answer:

I would use a state machine with about 3-4 states. An enum variable indicates the current state, and a switch-case statement has one case per state containing that state's operations, all inside a loop that moves the machine from one state to the next.
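A minimal sketch of that enum-plus-switch design, assuming a fixed price and that any change is returned with the item. The names `State`, `VendingMachine`, `step` and the `price` field are hypothetical, not from the question. All per-state behaviour sits in one `switch`, so adding a state means editing this function (and any other `switch` over `State`), which is the scaling concern the interviewer raises next.

```cpp
// Hypothetical states for the single-item machine in the question.
enum class State { Idle, HasMoney, Vending, SoldOut };

// The two state variables from the question, plus an assumed fixed price.
struct VendingMachine {
  unsigned money = 0;      // credit inserted so far
  unsigned inventory = 0;  // items left
  unsigned price = 1;      // assumption: one fixed price per item
  State state = State::Idle;
};

// One iteration of the main loop. All per-state behaviour lives in this
// single switch, so every new state means another case here.
void step(VendingMachine &m, unsigned coinsInserted) {
  switch (m.state) {
  case State::Idle:
    if (m.inventory == 0) {
      m.state = State::SoldOut;  // nothing to sell; coins are not accepted
    } else if (coinsInserted > 0) {
      m.money += coinsInserted;
      m.state = State::HasMoney;
    }
    break;
  case State::HasMoney:
    m.money += coinsInserted;
    if (m.money >= m.price) m.state = State::Vending;
    break;
  case State::Vending:
    --m.inventory;
    m.money = 0;  // assumption: any change is returned with the item
    m.state = m.inventory == 0 ? State::SoldOut : State::Idle;
    break;
  case State::SoldOut:
    break;  // stays here until a refill (not modelled in this sketch)
  }
}
```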

The next question:

But a switch-case statement does not "scale well" as more states are added or as the operations in existing states are modified. How would you deal with that problem?

I couldn't answer this question at the time, but later I thought I could:

  • have different functions for different states (each function corresponding to a state)
  • have a std::map from string to function, where the string names a state, to look up the corresponding state function.
  • have the main function hold a string variable (starting at the initial state) and, in a loop, call the function corresponding to that variable. Each function does the operations needed and returns the new state to the main function.
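The map-of-functions idea above can be sketched as follows. It assumes a fixed price, coins supplied in a hypothetical `pendingCoins` field, and a hypothetical terminal state name `"done"` that stops the loop; `Machine`, `StateFn`, `makeStateTable` and `run` are illustrative names only. Adding a state means adding one entry to the table rather than editing a shared `switch`.

```cpp
#include <functional>
#include <map>
#include <string>

// Hypothetical shared data that every state function reads and updates.
struct Machine {
  unsigned money = 0;
  unsigned inventory = 0;
  unsigned price = 1;         // assumption: one fixed price
  unsigned pendingCoins = 0;  // assumption: coins inserted before this run
};

// A state does its work and returns the name of the next state.
using StateFn = std::function<std::string(Machine &)>;

std::map<std::string, StateFn> makeStateTable() {
  std::map<std::string, StateFn> table;
  table["idle"] = [](Machine &m) -> std::string {
    if (m.inventory == 0) return "sold_out";
    if (m.pendingCoins == 0) return "done";  // no input left: stop
    m.money += m.pendingCoins;
    m.pendingCoins = 0;
    return m.money >= m.price ? "vend" : "done";
  };
  table["vend"] = [](Machine &m) -> std::string {
    m.money -= m.price;  // leftover credit stays in `money`
    --m.inventory;
    return "idle";
  };
  table["sold_out"] = [](Machine &) -> std::string { return "done"; };
  return table;
}

// The "main function" loop: look up the current state, call it, and repeat
// until a state returns "done". Returns the last state visited.
std::string run(Machine &m, std::string state) {
  const auto table = makeStateTable();
  std::string last = state;
  while (state != "done") {
    last = state;
    state = table.at(state)(m);  // throws std::out_of_range for unknown names
  }
  return last;
}
```

One trade-off of string keys: a misspelled state name is only caught at run time (`table.at` throws `std::out_of_range`), whereas an `enum class` key would be checked at compile time.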

My questions are:

  • What is the problem with switch-case statements with respect to scalability in large-scale software systems?
  • If there is one, would my solution (which I feel is more modular than one long, linear block of code) resolve the problem?

The interviewer was expecting answers based on C++ idioms and design patterns for large-scale software systems.


Solution

  • I was thinking of a more OO approach, using the State pattern:

    The Machine:

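    // Machine.h
    #pragma once
    
    // A forward declaration is enough here; including MachineStates.h would
    // create a circular include, since MachineStates.h includes this header.
    class AbstractState;
    
    class Machine {
      friend class AbstractState;
    
    public:
      Machine(unsigned int _stock);
      void sell(unsigned int quantity);
      void refill(unsigned int quantity);
      unsigned int getStock();
      ~Machine();
    
      // Machine owns `state`, so a copy would delete the same state twice.
      Machine(const Machine &) = delete;
      Machine &operator=(const Machine &) = delete;
    
    private:
      unsigned int stock;
      AbstractState *state; // owned; replaced through AbstractState::setState
    };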
    
    
    // --------
    
    // Machine.cpp
    #include "Machine.h"
    #include "MachineStates.h"
    
    Machine::Machine(unsigned int _stock)
        : stock(_stock),
          // Start in SoldOut if the machine is created empty.
          state(_stock > 0 ? static_cast<AbstractState *>(new Normal())
                           : static_cast<AbstractState *>(new SoldOut())) {}
    
    Machine::~Machine() { delete state; }
    
    void Machine::sell(unsigned int quantity) { state->sell(*this, quantity); }
    
    void Machine::refill(unsigned int quantity) { state->refill(*this, quantity); }
    
    unsigned int Machine::getStock() { return stock; }
    

    The States:

    // MachineStates.h
    #pragma once
    
    #include "Machine.h"
    #include <exception>
    #include <stdexcept>
    
    class Machine;
    
    class AbstractState {
    public:
      virtual void sell(Machine &machine, unsigned int quantity) = 0;
      virtual void refill(Machine &machine, unsigned int quantity) = 0;
      virtual ~AbstractState();
    
    protected:
      void setState(Machine &machine, AbstractState *st);
      void updateStock(Machine &machine, unsigned int quantity);
    };
    
    class Normal : public AbstractState {
    public:
      virtual void sell(Machine &machine, unsigned int quantity);
      virtual void refill(Machine &machine, unsigned int quantity);
      virtual ~Normal();
    };
    
    class SoldOut : public AbstractState {
    public:
      virtual void sell(Machine &machine, unsigned int quantity);
      virtual void refill(Machine &machine, unsigned int quantity);
      virtual ~SoldOut();
    };
    
    // --------
    
    // MachineStates.cpp
    #include "MachineStates.h"
    
    AbstractState::~AbstractState() {}
    
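    // Swaps in the new state and deletes the old one. The old state is usually
    // the object whose member function is calling setState (i.e. `this`), so
    // callers must not touch their own members after this call.
    void AbstractState::setState(Machine &machine, AbstractState *state) {
      AbstractState *aux = machine.state;
      machine.state = state;
      delete aux;
    }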
    
    void AbstractState::updateStock(Machine &machine, unsigned int quantity) {
      machine.stock = quantity;
    }
    
    Normal::~Normal() {}
    
    void Normal::sell(Machine &machine, unsigned int quantity) {
      unsigned int currStock = machine.getStock();
      if (currStock < quantity) {
        throw std::runtime_error("Not enough stock");
      }
    
      updateStock(machine, currStock - quantity);
    
      if (machine.getStock() == 0) {
        setState(machine, new SoldOut());
      }
    }
    
    void Normal::refill(Machine &machine, unsigned int quantity) {
      unsigned int currStock = machine.getStock();
      updateStock(machine, currStock + quantity);
    }
    
    SoldOut::~SoldOut() {}
    
    void SoldOut::sell(Machine &machine, unsigned int quantity) {
      throw std::runtime_error("Sold out!");
    }
    
    void SoldOut::refill(Machine &machine, unsigned int quantity) {
      updateStock(machine, quantity);
      setState(machine, new Normal());
    }
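A variation on the same State pattern, sketched in one file with `std::unique_ptr` (an assumption of this sketch, not how the answer above is written): instead of calling `setState`, which deletes the calling object while its member function is still running, each event returns the next state, or `nullptr` to stay put, and `Machine` swaps it in after the call returns. The public `setStock` and the private `apply` helper are simplifications for illustration; the original keeps stock updates behind `friend` access. Unlike the original `SoldOut::refill`, this version stays sold out when refilled with zero items.

```cpp
#include <memory>
#include <stdexcept>
#include <utility>

class Machine;

// Each event handler returns the next state, or nullptr to stay in this one.
class State {
public:
  virtual ~State() = default;
  virtual std::unique_ptr<State> sell(Machine &m, unsigned quantity) = 0;
  virtual std::unique_ptr<State> refill(Machine &m, unsigned quantity) = 0;
};

class Machine {
public:
  explicit Machine(unsigned stock);
  void sell(unsigned quantity) { apply(state_->sell(*this, quantity)); }
  void refill(unsigned quantity) { apply(state_->refill(*this, quantity)); }
  unsigned getStock() const { return stock_; }
  void setStock(unsigned stock) { stock_ = stock; }  // simplification: public

private:
  // Runs only after the old state's handler has returned, so no state ever
  // deletes itself mid-call. If the handler throws, the state is unchanged.
  void apply(std::unique_ptr<State> next) {
    if (next) state_ = std::move(next);
  }

  unsigned stock_;
  std::unique_ptr<State> state_;  // owned; also makes Machine non-copyable
};

class Normal : public State {
public:
  std::unique_ptr<State> sell(Machine &m, unsigned quantity) override;
  std::unique_ptr<State> refill(Machine &m, unsigned quantity) override {
    m.setStock(m.getStock() + quantity);
    return nullptr;  // still in Normal
  }
};

class SoldOut : public State {
public:
  std::unique_ptr<State> sell(Machine &, unsigned) override {
    throw std::runtime_error("Sold out!");
  }
  std::unique_ptr<State> refill(Machine &m, unsigned quantity) override {
    m.setStock(quantity);
    if (quantity == 0) return nullptr;  // still nothing to sell
    return std::make_unique<Normal>();
  }
};

// Defined after SoldOut so that SoldOut is a complete type here.
std::unique_ptr<State> Normal::sell(Machine &m, unsigned quantity) {
  if (m.getStock() < quantity) throw std::runtime_error("Not enough stock");
  m.setStock(m.getStock() - quantity);
  if (m.getStock() == 0) return std::make_unique<SoldOut>();
  return nullptr;
}

Machine::Machine(unsigned stock) : stock_(stock) {
  if (stock > 0)
    state_ = std::make_unique<Normal>();
  else
    state_ = std::make_unique<SoldOut>();
}
```

Adding a state still means adding one class; the difference is only in who owns and destroys the state objects.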
    

    I'm not used to programming in C++, but this code apparently compiles with GCC 4.8.2 and clang 11.0.0, and Valgrind shows no leaks, so I guess it's fine. I'm not handling money, but it isn't needed to show the idea.

    To test it:

    // main.cpp
    #include "Machine.h"
    #include "MachineStates.h"
    #include <iostream>
    #include <stdexcept>
    
    int main() {
      Machine m(10), m2(0);
    
      m.sell(10);
      std::cout << "m: "
                << "Sold 10 items" << std::endl;
    
      try {
        m.sell(1);
      } catch (std::exception &e) {
        std::cerr << "m: " << e.what() << std::endl;
      }
    
      m.refill(20);
      std::cout << "m: "
                << "Refilled 20 items" << std::endl;
    
      m.sell(10);
      std::cout << "m: "
                << "Sold 10 items" << std::endl;
      std::cout << "m: "
                << "Remaining " << m.getStock() << " items" << std::endl;
    
      m.sell(5);
      std::cout << "m: "
                << "Sold 5 items" << std::endl;
      std::cout << "m: "
                << "Remaining " << m.getStock() << " items" << std::endl;
    
      try {
        m.sell(10);
      } catch (std::exception &e) {
        std::cerr << "m: " << e.what() << std::endl;
      }
    
      try {
        m2.sell(1);
      } catch (std::exception &e) {
        std::cerr << "m2: " << e.what() << std::endl;
      }
    
      return 0;
    }
    

    A little bit of Makefile:

    CXX = clang++
    CXXFLAGS = -g -Wall -std=c++17
    
    main: main.o Machine.o MachineStates.o
    	$(CXX) $(CXXFLAGS) -o main main.o Machine.o MachineStates.o
    
    main.o: main.cpp Machine.h MachineStates.h
    	$(CXX) $(CXXFLAGS) -c main.cpp
    
    # These two use make's built-in %.o: %.cpp rule, which reads CXX and CXXFLAGS.
    Machine.o: Machine.cpp Machine.h MachineStates.h
    
    MachineStates.o: MachineStates.cpp Machine.h MachineStates.h
    
    clean:
    	$(RM) main *.o
    

    Then run:

    make main
    ./main
    

    Output is:

    m: Sold 10 items
    m: Sold out!
    m: Refilled 20 items
    m: Sold 10 items
    m: Remaining 10 items
    m: Sold 5 items
    m: Remaining 5 items
    m: Not enough stock
    m2: Sold out!
    

    Now, to add a Broken state, you add another AbstractState child, plus the new damage/fix events on AbstractState and Machine:

    diff --git a/Machine.cpp b/Machine.cpp
    index 935d654..6c1f421 100644
    --- a/Machine.cpp
    +++ b/Machine.cpp
    @@ -13,4 +13,8 @@ void Machine::sell(unsigned int quantity) { state->sell(*this, quantity); }
     
     void Machine::refill(unsigned int quantity) { state->refill(*this, quantity); }
     
    +void Machine::damage() { state->damage(*this); }
    +
    +void Machine::fix() { state->fix(*this); }
    +
     unsigned int Machine::getStock() { return stock; }
    diff --git a/Machine.h b/Machine.h
    index aa983d0..706dde2 100644
    --- a/Machine.h
    +++ b/Machine.h
    @@ -12,6 +12,8 @@ public:
       Machine(unsigned int _stock);
       void sell(unsigned int quantity);
       void refill(unsigned int quantity);
    +  void damage();
    +  void fix();
       unsigned int getStock();
       ~Machine();
     
    diff --git a/MachineStates.cpp b/MachineStates.cpp
    index 9656783..d35a53d 100644
    --- a/MachineStates.cpp
    +++ b/MachineStates.cpp
    @@ -13,6 +13,16 @@ void AbstractState::updateStock(Machine &machine, unsigned int quantity) {
       machine.stock = quantity;
     }
     
    +void AbstractState::damage(Machine &machine) {
    +  setState(machine, new Broken());
    +}
    +
    +void AbstractState::fix(Machine &machine) {
    +  setState(machine, machine.stock > 0
    +                        ? static_cast<AbstractState *>(new Normal())
    +                        : static_cast<AbstractState *>(new SoldOut()));
    +}
    +
     Normal::~Normal() {}
     
     void Normal::sell(Machine &machine, unsigned int quantity) {
    @@ -33,6 +43,10 @@ void Normal::refill(Machine &machine, unsigned int quantity) {
       updateStock(machine, currStock + quantity);
     }
     
    +void Normal::fix(Machine &machine) {
    +  throw std::runtime_error("If it ain't broke, don't fix it!");
    +}
    +
     SoldOut::~SoldOut() {}
     
     void SoldOut::sell(Machine &machine, unsigned int quantity) {
    @@ -43,3 +57,17 @@ void SoldOut::refill(Machine &machine, unsigned int quantity) {
       updateStock(machine, quantity);
       setState(machine, new Normal());
     }
    +
    +void SoldOut::fix(Machine &machine) {
    +  throw std::runtime_error("If it ain't broke, don't fix it!");
    +}
    +
    +Broken::~Broken() {}
    +
    +void Broken::sell(Machine &machine, unsigned int quantity) {
    +  throw std::runtime_error("Machine is broken! Fix it before sell");
    +}
    +
    +void Broken::refill(Machine &machine, unsigned int quantity) {
    +  throw std::runtime_error("Machine is broken! Fix it before refill");
    +}
    diff --git a/MachineStates.h b/MachineStates.h
    index b117d3c..3921d35 100644
    --- a/MachineStates.h
    +++ b/MachineStates.h
    @@ -11,6 +11,8 @@ class AbstractState {
     public:
       virtual void sell(Machine &machine, unsigned int quantity) = 0;
       virtual void refill(Machine &machine, unsigned int quantity) = 0;
    +  virtual void damage(Machine &machine);
    +  virtual void fix(Machine &machine);
       virtual ~AbstractState();
     
     protected:
    @@ -22,6 +24,7 @@ class Normal : public AbstractState {
     public:
       virtual void sell(Machine &machine, unsigned int quantity);
       virtual void refill(Machine &machine, unsigned int quantity);
    +  virtual void fix(Machine &machine);
       virtual ~Normal();
     };
     
    @@ -29,5 +32,13 @@ class SoldOut : public AbstractState {
     public:
       virtual void sell(Machine &machine, unsigned int quantity);
       virtual void refill(Machine &machine, unsigned int quantity);
    +  virtual void fix(Machine &machine);
       virtual ~SoldOut();
     };
    +
    +class Broken : public AbstractState {
    +public:
    +  virtual void sell(Machine &machine, unsigned int quantity);
    +  virtual void refill(Machine &machine, unsigned int quantity);
    +  virtual ~Broken();
    +};
    diff --git a/main.cpp b/main.cpp
    index 8c57fed..82ea0bf 100644
    --- a/main.cpp
    +++ b/main.cpp
    @@ -39,11 +39,34 @@ int main() {
         std::cerr << "m: " << e.what() << std::endl;
       }
     
    +  m.damage();
    +  std::cout << "m: "
    +            << "Machine is broken" << std::endl;
    +  m.fix();
    +  std::cout << "m: "
    +            << "Fixed! In stock: " << m.getStock() << " items" << std::endl;
    +
       try {
         m2.sell(1);
       } catch (std::exception &e) {
         std::cerr << "m2: " << e.what() << std::endl;
       }
     
    +  try {
    +    m2.fix();
    +  } catch (std::exception &e) {
    +    std::cerr << "m2: " << e.what() << std::endl;
    +  }
    +
    +  m2.damage();
    +  std::cout << "m2: "
    +            << "Machine is broken" << std::endl;
    +
    +  try {
    +    m2.refill(10);
    +  } catch (std::exception &e) {
    +    std::cerr << "m2: " << e.what() << std::endl;
    +  }
    +
       return 0;
     }
    

    To support more products, you would need a map from each product to its in-stock quantity, and so on.
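A minimal sketch of that multi-product map, assuming products are identified by name. The `Inventory` class and its methods are hypothetical. A machine-level `SoldOut` state could then be entered when `soldOut()` returns true, while a shortage of a single product becomes an error from `sell`.

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical multi-product inventory: product name -> units in stock.
class Inventory {
public:
  void refill(const std::string &product, unsigned quantity) {
    stock_[product] += quantity;  // operator[] inserts 0 for a new product
  }

  void sell(const std::string &product, unsigned quantity) {
    auto it = stock_.find(product);
    if (it == stock_.end())
      throw std::invalid_argument("Unknown product: " + product);
    if (it->second < quantity) throw std::runtime_error("Not enough stock");
    it->second -= quantity;
  }

  unsigned getStock(const std::string &product) const {
    auto it = stock_.find(product);
    return it == stock_.end() ? 0u : it->second;
  }

  // The whole machine is sold out only when every product is at zero.
  bool soldOut() const {
    for (const auto &entry : stock_)
      if (entry.second > 0) return false;
    return true;
  }

private:
  std::map<std::string, unsigned> stock_;
};
```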

    The final code can be found in this repo.