Search code examples
Tags: python, discord, discord.py, chatbot, code-cleanup

Discord.py: How to make clean dialog trees?


My goal is to clean up my code so that I can build dialog trees more easily, without repeated copy-pasted pieces that don't need to be there. I can do this cleanly in plain Python, but discord.py seems to have different requirements. Here is a sample of my current, very redundant code:

    if 'I need help' in message.content.lower():
        await message.channel.trigger_typing()
        await asyncio.sleep(2)
        response = 'Do you need help'
        await message.channel.send(response)
        await message.channel.send("yes or no?")

        def check(msg):
            return msg.author == message.author and msg.channel == message.channel and msg.content.lower() in ["yes", "no"]
        msg = await client.wait_for("message", check=check)

        if msg.content.lower() == "no":
            await message.channel.trigger_typing()
            await asyncio.sleep(2)
            response = 'okay'
            await message.channel.send(response)

        if msg.content.lower() == "yes":
            await message.channel.trigger_typing()
            await asyncio.sleep(2)
            response = 'I have something. Would you like to continue?'
            await message.channel.send(response)
            await message.channel.send("yes or no?")

            def check(msg):
                return msg.author == message.author and msg.channel == message.channel and msg.content.lower() in ["yes", "no"]
            msg = await client.wait_for("message", check=check)

            if msg.content.lower() == "no":
                await message.channel.trigger_typing()
                await asyncio.sleep(2)
                response = 'Okay'
                await message.channel.send(response)

I've tried to write functions to handle the repeated code, but haven't been successful. For example, using:

async def respond(response, channel, delay=2):
    """Show the typing indicator in *channel* for *delay* seconds, then send *response*.

    Uses ``async with channel.typing()`` instead of ``channel.trigger_typing()``:
    the context manager works on discord.py 1.x and 2.x, while
    ``trigger_typing`` was removed in 2.0.
    """
    async with channel.typing():
        await asyncio.sleep(delay)
    await channel.send(response)
# Surrounding code elided.
...
# Call site: must run inside an async function (e.g. an event handler), since it uses await.
await respond(response, message.channel)

Ideally, I'd like to be able to define the dialog tree itself with something like this, as I can in plain Python:

if __name__ == '__main__':
    # Dialog tree keyed by node id. Each node holds the prompt text and a list of
    # (reply, next node id) options; targets 3-6 are not defined in this excerpt.
    hallucinated = {
        1: dict(
            Text=[
                "It sounds like you may be hallucinating, would you like help with trying to disprove it?",
            ],
            Options=[("yes", 2), ("no", 3)],
        ),
        2: dict(
            Text=["Is it auditory, visual, or tactile?"],
            Options=[("auditory", 4), ("visual", 5), ("tactile", 6)],
        ),
    }

Solution

  • Your general idea is correct: it is possible to represent such a system with a structure similar to the one you described. It's called a finite state machine. I've written an example of how one of these might be implemented -- this particular one uses a structure similar to interactive fiction games like Zork, but the same principle applies to dialog trees as well.

    import asyncio
    import copy
    import logging
    import os
    import traceback
    from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union

    import discord
    logging.basicConfig(level=logging.DEBUG)

    # discord.py 2.x requires an explicit `intents` argument, and reading
    # message text needs the privileged message-content intent (it must also be
    # enabled for the bot in the Discord developer portal).
    intents = discord.Intents.default()
    if hasattr(intents, "message_content"):  # this flag only exists in discord.py >= 2.0
        intents.message_content = True
    client = discord.Client(intents=intents)

    # Node identifiers are plain strings.
    NodeId = str

    # Reply that aborts a running dialog.
    ABORT_COMMAND = '!abort'
    
    class BadFSMError(ValueError):
        """ Common base for every error raised while a dialog FSM is being evaluated. """

    class FSMAbortedError(BadFSMError):
        """ The user cancelled the running dialog with the abort command. """

    class LinkToNowhereError(BadFSMError):
        """ A node points at a target node ID that is not part of the FSM. """

    class NoEntryNodeError(BadFSMError):
        """ Evaluation was requested but no entry node has been configured. """
    
    class Node:
        """ Node in the dialog FSM.

        :param text_on_enter: Text sent to the channel when the node is entered
            (nothing is sent if it is empty or None).
        :param choices: Maps the exact text of a user reply either to the ID of
            the next node, or to a ``(next_id, mod_func)`` tuple. ``mod_func`` is
            called with this node before moving on, so it may rewrite the node's
            text or choices.
        :param delay_before_text: Seconds to show the typing indicator before
            sending ``text_on_enter``.
        :param is_exit_node: If true, the FSM stops after this node's text is
            sent, without waiting for a reply.
        """
        def __init__(self,
                     text_on_enter: Optional[str],
                     choices: Dict[str, Union[NodeId, Tuple[NodeId, Callable[["Node"], None]]]],
                     delay_before_text: int = 2, is_exit_node: bool = False):
            self.text_on_enter = text_on_enter
            self.choices = choices
            self.delay_before_text = delay_before_text
            self.is_exit_node = is_exit_node

        async def walk_from(self, message) -> Optional[NodeId]:
            """ Enter this node and return the ID of the node the user picks.

            Shows the typing indicator for ``delay_before_text`` seconds, sends
            ``text_on_enter``, then waits for a reply from the author of
            ``message`` in the same channel. A reply that is not a key of
            ``choices`` triggers a re-prompt listing the valid choices.

            :returns: The next node ID, or None if this is an exit node.
            :raises FSMAbortedError: If the user replies with ``ABORT_COMMAND``
                and that text is not itself one of the choices.
            """
            async with message.channel.typing():
                await asyncio.sleep(self.delay_before_text)
            if self.text_on_enter:
                await message.channel.send(self.text_on_enter)

            if self.is_exit_node:
                return None

            def is_my_message(msg):
                # Only the user who started this dialog, in the same channel, may answer.
                return msg.author == message.author and msg.channel == message.channel

            while True:
                user_message = await client.wait_for("message", check=is_my_message)
                choice = user_message.content
                if choice in self.choices:
                    break
                if choice == ABORT_COMMAND:
                    raise FSMAbortedError
                await message.channel.send("Please select one of the following: " + ', '.join(self.choices))

            result = self.choices[choice]
            if isinstance(result, tuple):
                # (next_id, mod_func): let the callback mutate this node first.
                next_id, mod_func = result
                mod_func(self)
            else:
                next_id = result
            return next_id
    
    class DialogFSM:
        """ Dialog finite state machine.

        :param nodes: Mapping from node ID to node. It is stored as-is (not
            copied), so ``add_node`` also adds to it. Defaults to a new, empty
            dict for each instance.
        :param entry_node: ID of the node that evaluation starts from.
        """
        def __init__(self, nodes: "Optional[Dict[NodeId, Node]]" = None,
                     entry_node: "Optional[NodeId]" = None):
            # Create the dict per instance: a `{}` default would be one shared
            # object, so nodes added to one FSM would show up in all others.
            self.nodes: "Dict[NodeId, Node]" = {} if nodes is None else nodes
            self.entry_node: "Optional[NodeId]" = entry_node

        def add_node(self, id: "NodeId", node: "Node"):
            """ Add ``node`` to the FSM under ``id``.

            :raises ValueError: If a node with that ID already exists.
            """
            if id in self.nodes:
                raise ValueError(f"Node with ID {id} already exists!")
            self.nodes[id] = node

        def set_entry(self, id: "NodeId"):
            """ Make the existing node ``id`` the entry node.

            :raises ValueError: If no node has that ID.
            """
            if id not in self.nodes:
                raise ValueError(f"Tried to set unknown node {id} as entry")
            self.entry_node = id

        async def evaluate(self, message) -> None:
            """ Evaluate the FSM, beginning from this message.

            Starting at the entry node, repeatedly awaits the current node's
            ``walk_from(message)`` and follows the node ID it returns; stops when
            it returns None.

            :raises NoEntryNodeError: If no entry node is set.
            :raises LinkToNowhereError: If the entry node, or a node ID returned by
                a node, is not in ``nodes``.
            """
            if self.entry_node is None:
                raise NoEntryNodeError
            if self.entry_node not in self.nodes:
                # Possible when the entry was passed to the constructor without validation.
                raise LinkToNowhereError(f"The entry node {self.entry_node} doesn't exist")
            current_node = self.nodes[self.entry_node]
            while True:
                next_node_id = await current_node.walk_from(message)
                if next_node_id is None:
                    return
                if next_node_id not in self.nodes:
                    raise LinkToNowhereError(f"A node links to {next_node_id}, which doesn't exist")
                current_node = self.nodes[next_node_id]
    
    
    def break_glass(node):
        """ Choice callback for the blue room: describe the shattered ceiling and
        swap the 'break' choice for a 'u' choice leading to the 'exit' node. """
        shattered = "You are in a blue room. The remains of a shattered stained glass ceiling are scattered around. There is a step-ladder you can use to climb out."
        node.text_on_enter = shattered
        room_choices = node.choices
        room_choices.pop('break')
        room_choices['u'] = 'exit'
    nodes = {
        'central': Node(
            text_on_enter="You are in a white room. There are doors leading east, north, and a ladder going up.",
            choices={'n': 'xroom', 'e': 'yroom', 'u': 'zroom'},
        ),
        'xroom': Node(
            text_on_enter="You are in a red room. There is a large 'X' on the wall in front of you. The only exit is south.",
            choices={'s': 'central'},
        ),
        'yroom': Node(
            text_on_enter="You are in a green room. There is a large 'Y' on the wall to the right. The only exit is west.",
            choices={'w': 'central'},
        ),
        'zroom': Node(
            text_on_enter="You are in a blue room. There is a large 'Z' on the stained glass ceiling. There is a step-ladder and a hammer.",
            # 'break' stays in this room and rewrites it via break_glass.
            choices={'d': 'central', 'break': ('zroom', break_glass)},
        ),
        'exit': Node(
            text_on_enter="You have climbed out into a forest. You see the remains of a glass ceiling next to you. You are safe now.",
            choices={},
            is_exit_node=True,
        ),
    }

    # Every run starts in the white room.
    fsm = DialogFSM(nodes, entry_node='central')
    
    @client.event
    async def on_message(msg):
        """ Run the dialog FSM for the author of a ``!begin`` message.

        Each run evaluates a deep copy of ``fsm``. Choice callbacks such as
        ``break_glass`` mutate nodes, so sharing one FSM would leak those changes
        into every later or concurrent run.
        """
        if msg.content != '!begin':
            return
        try:
            await copy.deepcopy(fsm).evaluate(msg)
        except FSMAbortedError:
            await msg.channel.send("FSM aborted")
        except Exception:
            # Narrower than a bare `except:`, which would also trap task
            # cancellation and interpreter shutdown.
            logging.exception("Dialog FSM failed")
            # Discord rejects messages longer than 2000 characters; keep the
            # most informative tail of the traceback.
            await msg.channel.send("```\n" + traceback.format_exc()[-1900:] + "\n```")
        else:
            await msg.channel.send("FSM terminated successfully")
    
    if __name__ == "__main__":
        # Start the bot only when run as a script, and keep the token out of the
        # source by reading it from the environment.
        client.run(os.environ["DISCORD_TOKEN"])
    

    Here's a sample run:

    Screenshot of an interaction with the bot in Discord