-
Notifications
You must be signed in to change notification settings - Fork 778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upgrade to Pydantic V2 #650
base: main
Are you sure you want to change the base?
Changes from 6 commits
62422f0
6ea256a
5af454d
914137b
2d137bb
7b8528e
15aa9f7
2f3bad0
cffbf2a
9a1002d
82c0333
ae3610f
69ef211
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,9 @@ | |
import json | ||
import os | ||
from contextlib import contextmanager | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast | ||
from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast | ||
|
||
from pydantic import BaseModel, Field, validator | ||
from pydantic import ConfigDict, BaseModel, Field, field_validator | ||
|
||
from ..libs.constants.default_config import default_config_json | ||
from ..libs.llm.base import LLM | ||
|
@@ -28,6 +28,7 @@ | |
from .context import ContextProvider | ||
from .main import ContextProviderDescription, Policy, SlashCommandDescription, Step | ||
from .models import MODEL_CLASSES, Models | ||
from typing import Iterator | ||
|
||
|
||
class StepWithParams(BaseModel): | ||
|
@@ -43,11 +44,11 @@ class ContextProviderWithParams(BaseModel): | |
class SlashCommand(BaseModel): | ||
name: str | ||
description: str | ||
step: Union[Type[Step], StepName, str] | ||
step: Annotated[Union[Type[Step], StepName], Field()] = Field(default=None, validate_default=True) | ||
params: Optional[Dict] = {} | ||
|
||
# Allow step class for the migration | ||
@validator("step", pre=True, always=True) | ||
@field_validator("step") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pydantic Docs: https://docs.pydantic.dev/latest/migration/#validator-and-root_validator-are-deprecated |
||
def step_is_string(cls, v): | ||
if isinstance(v, str): | ||
return v | ||
|
@@ -257,10 +258,22 @@ class SerializedContinueConfig(BaseModel): | |
@staticmethod | ||
@contextmanager | ||
def edit_config(): | ||
config = SerializedContinueConfig.parse_file(CONFIG_JSON_PATH) | ||
# Read the JSON file and parse it into a dictionary | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pydantic Docs: https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel |
||
with open(CONFIG_JSON_PATH, 'r', encoding='utf-8') as file: | ||
data = json.load(file) | ||
|
||
# Create an instance of SerializedContinueConfig from the dictionary | ||
config = SerializedContinueConfig(**data) | ||
|
||
# Yield the config object for editing within the with-block | ||
yield config | ||
with open(CONFIG_JSON_PATH, "w") as f: | ||
f.write(config.json(exclude_none=True, exclude_defaults=True, indent=2)) | ||
|
||
# After editing, write the serialized config back to the JSON file | ||
with open(CONFIG_JSON_PATH, "w", encoding='utf-8') as file: | ||
# Serialize the Pydantic model instance (`dict` method creates a serializable output) | ||
json.dump(config.dict(), file, indent=4) | ||
|
||
|
||
|
||
@staticmethod | ||
def set_temperature(temperature: float): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
from typing import Any, Awaitable, Callable, List, Optional | ||
|
||
from meilisearch_python_async import Client | ||
from pydantic import BaseModel, Field | ||
from pydantic import ConfigDict, BaseModel, Field | ||
|
||
from ..libs.util.create_async_task import create_async_task | ||
from ..libs.util.devdata import dev_data_logger | ||
|
@@ -79,10 +79,7 @@ class ContextProvider(BaseModel): | |
selected_items: List[ContextItem] = Field( | ||
[], description="List of selected items in the ContextProvider" | ||
) | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
exclude = {"ide", "delete_documents", "update_documents"} | ||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude={"ide", "delete_documents", "update_documents"}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was updated by Pydantic's Pydantic Doc: |
||
|
||
def get_description(self) -> ContextProviderDescription: | ||
return ContextProviderDescription( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
import json | ||
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union, cast | ||
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Optional, Union, cast | ||
|
||
from pydantic import BaseModel, Field, validator | ||
from pydantic.schema import schema | ||
from pydantic import field_validator, ConfigDict, BaseModel, Field, field_validator | ||
from pydantic.json_schema import model_json_schema | ||
|
||
from ..models.main import ContinueBaseModel | ||
from .observation import Observation | ||
|
@@ -20,13 +20,14 @@ class ChatMessage(ContinueBaseModel): | |
content: str = "" | ||
name: Optional[str] = None | ||
# A summary for pruning chat context to fit context window. Often the Step name. | ||
summary: str = Field(default=None, title="Summary") | ||
summary: Annotated[str, Field()] = Field(default=None, title="Summary", validate_default=True) | ||
function_call: Optional[FunctionCall] = None | ||
|
||
@validator("summary", pre=True, always=True) | ||
def summary_is_content(cls, summary, values): | ||
|
||
@field_validator("summary") | ||
def summary_is_content(cls, summary, val_info): | ||
if summary is None: | ||
return values.get("content", "") | ||
return val_info.data.get("content", "") | ||
return summary | ||
|
||
def to_dict(self, with_functions: bool = False) -> Dict[str, str]: | ||
|
@@ -84,7 +85,7 @@ def traverse(obj): | |
|
||
def step_to_json_schema(step) -> Dict[str, Any]: | ||
pydantic_class = step.__class__ | ||
schema_data = schema([pydantic_class]) | ||
schema_data = model_json_schema([pydantic_class]) | ||
resolved_schema = cast(Dict[str, Any], resolve_refs(schema_data)) | ||
parameters = resolved_schema["definitions"][pydantic_class.__name__] | ||
for parameter in unincluded_parameters: | ||
|
@@ -142,18 +143,18 @@ def dict(self, *args, **kwargs): | |
|
||
|
||
class StepDescription(BaseModel): | ||
step_type: str | ||
name: str | ||
description: str | ||
step_type: Optional[str] = None | ||
name: Optional[str] = None | ||
description: Optional[str] = None | ||
|
||
params: Dict[str, Any] | ||
params: Optional[Dict[str, Any]] | ||
|
||
hide: bool | ||
depth: int | ||
hide: Optional[bool] = None | ||
depth: Optional[int] = None | ||
|
||
error: Optional[ContinueError] = None | ||
observations: List[Observation] = [] | ||
logs: List[str] = [] | ||
observations: Optional[List[Observation]] = [] | ||
logs: Optional[List[str]] = [] | ||
|
||
def update(self, update: "UpdateStep"): | ||
if isinstance(update, DeltaStep): | ||
|
@@ -168,9 +169,9 @@ class SessionUpdate(BaseModel): | |
index: int | ||
update: "UpdateStep" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
stop: Optional[bool] = None | ||
|
||
class Config: | ||
smart_union = True | ||
# TODO[pydantic]: The following keys were removed: `smart_union`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Says smart_union is deprecated but no suggestion on how to migrate it https://docs.pydantic.dev/latest/migration/#changes-to-config |
||
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. | ||
model_config = ConfigDict() | ||
|
||
def dict(self, *args, **kwargs): | ||
d = super().dict(*args, **kwargs) | ||
|
@@ -201,7 +202,8 @@ class ContextItemId(BaseModel): | |
provider_title: str | ||
item_id: str | ||
|
||
@validator("provider_title", "item_id") | ||
@field_validator("provider_title", "item_id") | ||
@classmethod | ||
def must_be_valid_id(cls, v): | ||
import re | ||
|
||
|
@@ -247,7 +249,8 @@ class ContextItem(BaseModel): | |
description: ContextItemDescription | ||
content: str | ||
|
||
@validator("content", pre=True) | ||
@field_validator("content", mode="before") | ||
@classmethod | ||
def content_must_be_string(cls, v): | ||
if v is None: | ||
return "" | ||
|
@@ -265,11 +268,9 @@ class SessionInfo(ContinueBaseModel): | |
|
||
|
||
class ContinueConfig(ContinueBaseModel): | ||
system_message: Optional[str] | ||
temperature: Optional[float] | ||
|
||
class Config: | ||
extra = "allow" | ||
system_message: Optional[str] = None | ||
temperature: Optional[float] = None | ||
model_config = ConfigDict(extra="allow") | ||
|
||
def dict(self, **kwargs): | ||
original_dict = super().dict(**kwargs) | ||
|
@@ -317,22 +318,25 @@ def next( | |
|
||
|
||
class Step(ContinueBaseModel): | ||
name: str = Field(default=None, title="Name") | ||
name: Optional[Annotated[str, Field()]] =Field(default=None, title="Name", validate_default=True) | ||
|
||
hide: bool = False | ||
description: str = "" | ||
|
||
class_name: str = "Step" | ||
class_name: Annotated[str, Field()] = Field(default="Step", validate_default=True) | ||
|
||
|
||
@validator("class_name", pre=True, always=True) | ||
@field_validator("class_name") | ||
def class_name_is_class_name(cls, class_name): | ||
return cls.__name__ | ||
|
||
system_message: Union[str, None] = None | ||
chat_context: List[ChatMessage] = [] | ||
manage_own_chat_context: bool = False | ||
|
||
class Config: | ||
copy_on_model_validation = False | ||
# TODO[pydantic]: The following keys were removed: `copy_on_model_validation`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copy_on_model_validation deprecated but no replacment https://docs.pydantic.dev/latest/migration/#changes-to-config |
||
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. | ||
# P2MN: https://github.com/pydantic/pydantic/discussions/7225 There appears to be no replacement for this | ||
model_config = ConfigDict() | ||
|
||
async def describe(self, models: Models) -> str: | ||
if self.description is not None: | ||
|
@@ -348,7 +352,8 @@ def dict(self, *args, **kwargs): | |
d["description"] = self.description or "" | ||
return d | ||
|
||
@validator("name", pre=True, always=True) | ||
|
||
@field_validator("name") | ||
def name_is_class_name(cls, name): | ||
if name is None: | ||
return cls.__name__ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import uuid | ||
from typing import Any, Callable, Dict, List, Optional, Type, Union | ||
from typing import Annotated, Any, Callable, Dict, List, Optional, Type, Union | ||
|
||
from pydantic import BaseModel, validator | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
from ..libs.llm.anthropic import AnthropicLLM | ||
from ..libs.llm.base import LLM | ||
|
@@ -74,24 +74,24 @@ class Models(BaseModel): | |
|
||
default: Union[Any, LLM] | ||
summarize: Union[Any, LLM] | ||
edit: Union[Any, LLM] | ||
chat: Union[Any, LLM] | ||
summarize: Annotated[Union[Any, LLM], Field()] =Field(validate_default=True) | ||
edit: Annotated[Union[Any, LLM], Field()] =Field(validate_default=True) | ||
chat: Annotated[Union[Any, LLM], Field()] =Field(validate_default=True) | ||
|
||
saved: List[Union[Any, LLM]] = [] | ||
|
||
temperature: Optional[float] = None | ||
system_message: Optional[str] = None | ||
|
||
@validator( | ||
|
||
@field_validator( | ||
"summarize", | ||
"edit", | ||
"chat", | ||
pre=True, | ||
always=True, | ||
"chat" | ||
) | ||
def roles_not_none(cls, v, values): | ||
def roles_not_none(cls, v, val_info): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value: ModelField is now ValidationInfo. I renamed the argument from val_info to make it clear it is not value. |
||
if v is None: | ||
return values["default"] | ||
return cls.model_fields[val_info.field_name].default | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no default in the val_info so it now needs to be fetched from the class.model_fields... |
||
return v | ||
|
||
def dict(self, **kwargs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
uplifted openai v0.27.5 to v1.3.6 which was a breaking change.
As mentioned we'll need to uplift the entire OpenAI framework but it's a lot easier to implement