chat: fix issues with quickly switching between multiple chats (#2343)
* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
cebtenzzre committed May 15, 2024
1 parent 7f1c3d4 commit 7e1e00f
Showing 6 changed files with 179 additions and 143 deletions.
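
The bullets above hang off one new piece of UI-visible state: the three-valued trySwitchContextInProgress property on Chat (0 = no, 1 = waiting for model, 2 = switching context), which replaces the old trySwitchContextOfLoadedModelAttempted / trySwitchContextOfLoadedModelCompleted(bool) signal pair. A minimal sketch of how UI-side code might consume it; the watchSwitchProgress helper and the lambda wiring are assumptions for illustration, not part of the commit:

    #include <QObject>

    // Hypothetical UI-side observer; 'chat' is assumed to be a valid Chat*
    // exposing the property added in this commit. Illustrative wiring only.
    void watchSwitchProgress(Chat *chat)
    {
        QObject::connect(chat, &Chat::trySwitchContextInProgressChanged, chat, [chat] {
            switch (chat->trySwitchContextInProgress()) {
            case 0: /* idle: hide any switching indicator               */ break;
            case 1: /* "waiting for model" (another chat still owns it) */ break;
            case 2: /* "switching context" on the already-loaded model  */ break;
            }
        });
    }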
41 changes: 21 additions & 20 deletions gpt4all-chat/chat.cpp
@@ -54,7 +54,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);

     connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@@ -95,16 +95,6 @@ void Chat::processSystemPrompt()
     emit processSystemPromptRequested();
 }

-bool Chat::isModelLoaded() const
-{
-    return m_modelLoadingPercentage == 1.0f;
-}
-
-float Chat::modelLoadingPercentage() const
-{
-    return m_modelLoadingPercentage;
-}
-
 void Chat::resetResponseState()
 {
     if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@@ -167,9 +157,16 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
     if (loadingPercentage == m_modelLoadingPercentage)
         return;

+    bool wasLoading = isCurrentlyLoading();
+    bool wasLoaded = isModelLoaded();
+
     m_modelLoadingPercentage = loadingPercentage;
     emit modelLoadingPercentageChanged();
-    if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
+
+    if (isCurrentlyLoading() != wasLoading)
+        emit isCurrentlyLoadingChanged();
+
+    if (isModelLoaded() != wasLoaded)
         emit isModelLoadedChanged();
 }

@@ -247,10 +244,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
     if (m_modelInfo == modelInfo && isModelLoaded())
         return;

-    m_modelLoadingPercentage = std::numeric_limits<float>::min(); // small non-zero positive value
-    emit isModelLoadedChanged();
-    m_modelLoadingError = QString();
-    emit modelLoadingErrorChanged();
     m_modelInfo = modelInfo;
     emit modelInfoChanged();
     emit modelChangeRequested(modelInfo);
@@ -320,8 +313,9 @@ void Chat::forceReloadModel()

 void Chat::trySwitchContextOfLoadedModel()
 {
-    emit trySwitchContextOfLoadedModelAttempted();
-    m_llmodel->setShouldTrySwitchContext(true);
+    m_trySwitchContextInProgress = 1;
+    emit trySwitchContextInProgressChanged();
+    m_llmodel->requestTrySwitchContext();
 }

 void Chat::generatedNameChanged(const QString &name)
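
For context on the renamed request: the chat thread now marks itself as waiting (state 1) before asking the worker to reuse the already-loaded model, and the worker answers through the int-valued completion signal handled below. The ChatLLM side is not part of the hunks shown here, so the sequence in this sketch is an assumption based on the new 0/1/2 encoding:

    // Assumed ChatLLM-side sequence (worker code, not shown in this diff):
    emit trySwitchContextOfLoadedModelCompleted(2); // working: loaded model can be reused
    // ... restore this chat's saved state into the loaded model ...
    emit trySwitchContextOfLoadedModelCompleted(0); // done (or emit 0 right away to abandon)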
@@ -342,8 +336,10 @@ void Chat::handleRecalculating()

 void Chat::handleModelLoadingError(const QString &error)
 {
-    auto stream = qWarning().noquote() << "ERROR:" << error << "id";
-    stream.quote() << id();
+    if (!error.isEmpty()) {
+        auto stream = qWarning().noquote() << "ERROR:" << error << "id";
+        stream.quote() << id();
+    }
     m_modelLoadingError = error;
     emit modelLoadingErrorChanged();
 }
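
The empty-string guard is what implements the "do not report cancellation as a failure" bullet: a cancelled or reset load can clear the UI error state by passing an empty string, and only a non-empty error reaches the console (and, upstream, Mixpanel). The two call shapes below are illustrative assumptions, not taken from the diff:

    // Real failure: logged, then surfaced via modelLoadingErrorChanged().
    emit modelLoadingError(QStringLiteral("could not load model"));
    // Cancellation: clears m_modelLoadingError without logging (assumed call shape).
    emit modelLoadingError(QString());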
@@ -380,6 +376,11 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
     emit modelInfoChanged();
 }

+void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) {
+    m_trySwitchContextInProgress = value;
+    emit trySwitchContextInProgressChanged();
+}
+
 bool Chat::serialize(QDataStream &stream, int version) const
 {
     stream << m_creationDate;
17 changes: 13 additions & 4 deletions gpt4all-chat/chat.h
@@ -17,6 +17,7 @@ class Chat : public QObject
     Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
     Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
+    Q_PROPERTY(bool isCurrentlyLoading READ isCurrentlyLoading NOTIFY isCurrentlyLoadingChanged)
     Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
@@ -30,6 +31,8 @@ class Chat : public QObject
     Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
     Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
     Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
+    // 0=no, 1=waiting, 2=working
+    Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")

@@ -62,8 +65,9 @@ class Chat : public QObject

     Q_INVOKABLE void reset();
     Q_INVOKABLE void processSystemPrompt();
-    Q_INVOKABLE bool isModelLoaded() const;
-    Q_INVOKABLE float modelLoadingPercentage() const;
+    bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
+    bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
+    float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
     Q_INVOKABLE void prompt(const QString &prompt);
     Q_INVOKABLE void regenerateResponse();
     Q_INVOKABLE void stopGenerating();
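
The inline getters partition the loading percentage into mutually exclusive states (0.0 = not loaded, strictly between 0 and 1 = loading, 1.0 = loaded). A small sanity sketch of the implied invariant, assuming the percentage stays within [0, 1]; this check is illustrative, not part of the commit:

    #include <QtGlobal>

    // 1.0f is excluded from the isCurrentlyLoading() range, so the
    // two predicates can never both be true for the same chat.
    void checkLoadingInvariant(const Chat &chat)
    {
        Q_ASSERT(!(chat.isModelLoaded() && chat.isCurrentlyLoading()));
    }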
@@ -105,6 +109,8 @@ class Chat : public QObject
     QString device() const { return m_device; }
     QString fallbackReason() const { return m_fallbackReason; }

+    int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);

@@ -113,6 +119,7 @@ public Q_SLOTS:
     void nameChanged();
     void chatModelChanged();
     void isModelLoadedChanged();
+    void isCurrentlyLoadingChanged();
     void modelLoadingPercentageChanged();
     void modelLoadingWarning(const QString &warning);
     void responseChanged();
@@ -136,8 +143,7 @@ public Q_SLOTS:
     void deviceChanged();
     void fallbackReasonChanged();
     void collectionModelChanged();
-    void trySwitchContextOfLoadedModelAttempted();
-    void trySwitchContextOfLoadedModelCompleted(bool);
+    void trySwitchContextInProgressChanged();

 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -152,6 +158,7 @@ private Q_SLOTS:
     void handleFallbackReasonChanged(const QString &device);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
+    void handleTrySwitchContextOfLoadedModelCompleted(int value);

 private:
     QString m_id;
@@ -176,6 +183,8 @@ private Q_SLOTS:
     float m_modelLoadingPercentage = 0.0f;
     LocalDocsCollectionsModel *m_collectionModel;
     bool m_firstResponse = true;
+    int m_trySwitchContextInProgress = 0;
+    bool m_isCurrentlyLoading = false;
 };

 #endif // CHAT_H
6 changes: 5 additions & 1 deletion gpt4all-chat/chatlistmodel.h
@@ -195,7 +195,11 @@ class ChatListModel : public QAbstractListModel
     int count() const { return m_chats.size(); }

     // stop ChatLLM threads for clean shutdown
-    void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } }
+    void destroyChats()
+    {
+        for (auto *chat: m_chats) { chat->destroy(); }
+        ChatLLM::destroyStore();
+    }

     void removeChatFile(Chat *chat) const;
     Q_INVOKABLE void saveChats();
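
This hunk is the fix for the "memory leak on exit" bullet: destroying each chat stops its ChatLLM thread, and the new ChatLLM::destroyStore() call then frees any model still parked in the shared LLModelStore. A sketch of the assumed shutdown path that relies on this ordering (the caller is outside this diff):

    // Assumed application-exit wiring: stop the chats first, then tear down
    // the shared store so a cached model is not leaked.
    chatListModel->destroyChats(); // now also calls ChatLLM::destroyStore()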
