Skip to content

Commit

Permalink
Update pydantic to 2.0 (#9121)
Browse files Browse the repository at this point in the history
* bump pydantic version

* Bump dataprep

* Update nixtla

* update readme

* remove log

* Missed renames

* Fix nixtla forecast
  • Loading branch information
hamishfagg committed May 7, 2024
1 parent 255fb87 commit 36b3aab
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion mindsdb/integrations/handlers/timegpt_handler/README.md
@@ -1,5 +1,5 @@
# Briefly describe what ML framework does this handler integrate to MindsDB, and how?
TimeGPT is a zero-shot forecasting model developed by Nixtla, offered through their `nixtlats` Python package.
TimeGPT is a zero-shot forecasting model developed by Nixtla, offered through their `nixtla` Python package.

This handler provides a simple wrapper around TimeGPT, and provides easy ingestion of time series data from other MindsDB data sources into the TimeGPT API. User requires a Nixtla API key to use this handler.

Expand Down
@@ -1 +1 @@
nixtlats>=0.1.10
nixtla==0.5.0
12 changes: 6 additions & 6 deletions mindsdb/integrations/handlers/timegpt_handler/timegpt_handler.py
@@ -1,7 +1,7 @@
from typing import Optional, Dict

import pandas as pd
from nixtlats import TimeGPT
from nixtla import NixtlaClient

from mindsdb.integrations.libs.base import BaseMLEngine
from mindsdb.integrations.utilities.handler_utils import get_api_key
Expand Down Expand Up @@ -34,8 +34,8 @@ def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[
assert time_settings["is_timeseries"], "Specify time series settings in your query"

timegpt_token = get_api_key('timegpt', using_args, self.engine_storage, strict=True)
timegpt = TimeGPT(token=timegpt_token)
assert timegpt.validate_token(), "Invalid TimeGPT token provided."
timegpt = NixtlaClient(api_key=timegpt_token)
assert timegpt.validate_api_key(), "Invalid TimeGPT token provided."

model_args = {
'token': timegpt_token,
Expand Down Expand Up @@ -74,8 +74,8 @@ def predict(self, df, args={}):
args = args['predict_params']
prediction_df = self._transform_to_nixtla_df(df, model_args)

timegpt = TimeGPT(token=model_args['token'])
assert timegpt.validate_token(), "Invalid TimeGPT token provided."
timegpt = NixtlaClient(api_key=model_args['token'])
assert timegpt.validate_api_key(), "Invalid TimeGPT token provided."

forecast_df = timegpt.forecast(
prediction_df,
Expand All @@ -85,7 +85,7 @@ def predict(self, df, args={}):
freq=args.get("freq", model_args["freq"]), # automatically infers correct frequency if not provided by user
level=model_args["level"],
finetune_steps=args.get('finetune_steps', model_args['finetune_steps']),
validate_token=args.get('validate_token', model_args['validate_token']),
validate_api_key=args.get('validate_token', model_args['validate_token']),
date_features=args.get('date_features', model_args['date_features']),
date_features_to_one_hot=args.get('date_features_to_one_hot', model_args['date_features_to_one_hot']),
clean_ex_first=args.get('clean_ex_first', model_args['clean_ex_first']),
Expand Down
4 changes: 2 additions & 2 deletions requirements/requirements.txt
Expand Up @@ -18,13 +18,13 @@ walrus==0.9.3
flask-compress >= 1.0.0
appdirs >= 1.0.0
mindsdb-sql ~= 0.15.0
pydantic >= 1.8.2
pydantic >= 2.0.3
mindsdb-evaluator >= 0.0.7, < 0.1.0
checksumdir >= 1.2.0
duckdb == 0.9.1
requests >= 2.30.0
pydateinfer==0.3.0
dataprep_ml==0.0.22
dataprep_ml==0.0.23
dill == 0.3.6
numpy
pytz
Expand Down

0 comments on commit 36b3aab

Please sign in to comment.