I am integrating an AI model into a web app (this model needs to have some context in order to maintain a fluid conversation with the user) on a local deployment. The problem is the intrinsic structure of the thread.
I know how a thread pool works. The problem is that, when making multiple POST requests (for chatting with the bot), there is a possibility that a different, previously unused thread answers the request. The context then gets saved in the memory of that thread, not in the one we had been using before.
Main problem: context is being saved in different threads with different memory each.
Firstly, I want to mention that the solution should involve neither implementing cookies nor creating a file to save the context. The idea is to assign a thread per session_token.
I have tried the following:
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import AzureOpenAI
import os
from typing import Dict, List
from uuid import UUID, uuid4
# Load environment variables (API key / endpoint) from a local .env file.
load_dotenv()
# Router that the main FastAPI application mounts.
router = APIRouter()
class Message(BaseModel):
    """Class representing the message of a conversation"""
    role: str  # either 'user' or 'assistant'
    content: str  # raw message text
class Prompt(BaseModel):
    """Class for sending or receiving messages"""
    session_id: UUID  # identifies which stored conversation this prompt extends
    prompt: str  # the user's message text
class NewSessionResponse(BaseModel):
    """Response body for /ai/new_session: the id of the freshly created session."""
    session_id: UUID
class ResetRequest(BaseModel):
    """Request body identifying the session whose history should be cleared."""
    session_id: UUID
# Dictionary to store conversations
# NOTE(review): this lives in process memory — it is shared across requests only
# while a single worker process serves them all; with multiple worker processes
# each one gets its own independent copy. Verify the deployment's worker count.
conversations: Dict[UUID, List[Message]] = {}
@router.post("/ai/chat")
def chat(prompt: Prompt):
    """Send the user's prompt, plus the stored history, to the model.

    Declared as a plain ``def`` on purpose: the Azure OpenAI SDK call below is
    blocking, and FastAPI runs sync endpoints in its threadpool instead of
    stalling the event loop (which an ``async def`` with a blocking call does).

    Raises:
        HTTPException: 500 when the API key or endpoint is not configured.
    """
    api_key = os.getenv("key")
    api_url = os.getenv("endpoint url")
    if not api_key or not api_url:
        raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')
    client = AzureOpenAI("here goes some parameters")
    # Retrieve (or lazily create) the conversation history for this session;
    # setdefault avoids the separate membership check + second lookup.
    history = conversations.setdefault(prompt.session_id, [])
    # Record the user's message before building the request context.
    history.append(Message(role="user", content=prompt.prompt))
    # Create the context for the API request
    context = [{'role': msg.role, 'content': msg.content} for msg in history]
    # Request completion from the model
    response = client.chat.completions.create(
        model="gpt-35-turbo-4k-0613",
        messages=context
    )
    # Extract the model's response
    model_response = response.choices[0].message.content
    # Persist the assistant's reply so the next request sees the full thread.
    # (Debug `print(conversations)` removed — it leaked history to stdout.)
    history.append(Message(role='assistant', content=model_response))
    return {"response": model_response}
@router.post("/ai/new_session", response_model=NewSessionResponse)
async def newSession():
    """Creates a new conversation session"""
    new_id = uuid4()
    # Register an empty history so subsequent /ai/chat calls find the session.
    conversations[new_id] = []
    return NewSessionResponse(session_id=new_id)
I also tried to implement threading when managing requests, but it did not work. I suppose this happens because the Uvicorn threads and the threads created here are not the same.
I implemented a Redis database for saving the messages depending on the sessionID that the endpoint receives.
Notice that I had to modify the docker-compose.yaml file to connect to the DB.
Here you have the code:
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import AzureOpenAI
import os
from typing import List
from uuid import UUID, uuid4
import json
from app.core.redis_config import redis_db
# Load environment variables (API key / endpoint) from a local .env file.
load_dotenv()
# Router that the main FastAPI application mounts.
router = APIRouter()
class Message(BaseModel):
    """One chat turn: who spoke and what was said."""
    role: str  # 'user' or 'assistant'
    content: str  # raw message text
class Prompt(BaseModel):
    """Incoming chat request: session, user text, and which model to use."""
    session_id: UUID
    prompt: str
    ai_model: str  # model deployment name forwarded to the completions call
class NewSessionResponse(BaseModel):
    """Response body for /ai/new_session: the id of the freshly created session."""
    session_id: UUID
class ResetRequest(BaseModel):
    """Request body identifying the session whose history should be cleared."""
    session_id: UUID
def get_conversation(session_id: UUID) -> List[Message]:
    """In charge of loading the conversation from the ddbb"""
    raw = redis_db.get(str(session_id))
    if not raw:
        # Unknown session (or empty value stored): start with a fresh history.
        return []
    return [Message(**entry) for entry in json.loads(raw)]
def save_conversation(session_id: UUID, messages: List[Message]) -> None:
    """In charge of inserting the conversation into the ddbb.

    Serializes the whole history as one JSON array keyed by the session id.
    ``model_dump()`` replaces the Pydantic-v1 ``dict()`` call, which is
    deprecated under Pydantic v2 (the deprecation warning the old TODO noted).
    """
    payload = json.dumps([msg.model_dump() for msg in messages])
    redis_db.set(str(session_id), payload)
@router.post("/ai/chat")
def chat(prompt: Prompt):
    """Forward the prompt plus the stored history to the model and persist the reply.

    Declared as a sync ``def`` on purpose: the Azure OpenAI SDK call is
    blocking, and FastAPI executes sync endpoints in a worker threadpool,
    keeping the event loop free. Because the history lives in Redis keyed by
    session_id, any worker thread/process can serve any session.

    Raises:
        HTTPException: 500 when the API key or endpoint is not configured.
    """
    api_key = os.getenv("OPENAI_API_KEY", "")
    api_url = os.getenv("OPENAI_API_BASE", "")
    if not api_key or not api_url:
        raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')
    client = AzureOpenAI("some parameters")
    # Load the shared history, append the new user turn, and build the context.
    conversation = get_conversation(prompt.session_id)
    conversation.append(Message(role="user", content=prompt.prompt))
    context = [{'role': msg.role, 'content': msg.content} for msg in conversation]
    response = client.chat.completions.create(
        model=prompt.ai_model,
        messages=context
    )
    model_response = response.choices[0].message.content
    # Persist the assistant's reply so the next request sees the full thread.
    conversation.append(Message(role='assistant', content=model_response))
    save_conversation(prompt.session_id, conversation)
    return {"response": model_response}
@router.post("/ai/reset")
async def resetConversation(reset_request: ResetRequest):
    """Delete the stored history for a session; 400 when the session is unknown."""
    key = str(reset_request.session_id)
    if not redis_db.exists(key):
        raise HTTPException(status_code=400, detail='Invalid session ID')
    redis_db.delete(key)
    return {"response": True}
@router.post("/ai/new_session", response_model=NewSessionResponse)
async def newSession():
    """Allocate a fresh session id and seed an empty history in Redis."""
    new_id = uuid4()
    save_conversation(new_id, [])
    return NewSessionResponse(session_id=new_id)