Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions chatkit/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@
import json
from datetime import datetime
from pathlib import Path
from typing import (
Annotated,
Any,
Literal,
)
from typing import Annotated, Any, Literal

from jinja2 import Environment, StrictUndefined, Template
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
model_serializer,
)
from typing_extensions import NotRequired, TypedDict, deprecated
Expand Down Expand Up @@ -1147,8 +1142,6 @@ class WidgetTemplate:
```
"""

adapter: TypeAdapter[DynamicWidgetRoot] = TypeAdapter(DynamicWidgetRoot)

def __init__(self, definition: dict[str, Any]):
self.version = definition["version"]
if self.version != "1.0":
Expand All @@ -1163,7 +1156,7 @@ def __init__(self, definition: dict[str, Any]):
self.data_schema = definition.get("jsonSchema", {})

@classmethod
def from_file(cls, file_path: str) -> "WidgetTemplate":
def from_file(cls, file_path: str) -> WidgetTemplate:
path = Path(file_path)
if not path.is_absolute():
caller_frame = inspect.stack()[1]
Expand All @@ -1178,10 +1171,19 @@ def from_file(cls, file_path: str) -> "WidgetTemplate":
def build(
self, data: dict[str, Any] | BaseModel | None = None
) -> DynamicWidgetRoot:
if data is None:
data = {}
if isinstance(data, BaseModel):
data = data.model_dump()
rendered = self.template.render(**data)
rendered = self.template.render(**self._normalize_data(data))
widget_dict = json.loads(rendered)
return self.adapter.validate_python(widget_dict)
return DynamicWidgetRoot.model_validate(widget_dict)

def build_basic(self, data: dict[str, Any] | BaseModel | None = None) -> BasicRoot:
"""Separate method for building basic root widgets until BasicRoot is supported for streamed widgets."""
rendered = self.template.render(**self._normalize_data(data))
widget_dict = json.loads(rendered)
return BasicRoot.model_validate(widget_dict)

def _normalize_data(
self, data: dict[str, Any] | BaseModel | None
) -> dict[str, Any]:
if data is None:
return {}
return data.model_dump() if isinstance(data, BaseModel) else data
21 changes: 21 additions & 0 deletions tests/assets/widgets/basic_root.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"type": "Basic",
"children": [
{
"type": "Col",
"gap": 1,
"children": [
{
"type": "Title",
"value": "Harry Potter",
"size": "sm"
},
{
"type": "Text",
"value": "The boy who lived",
"size": "sm"
}
]
}
]
}
46 changes: 46 additions & 0 deletions tests/assets/widgets/basic_root.widget
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"version": "1.0",
"name": "Author preview",
"template": "{\"type\":\"Basic\",\"children\":[{\"type\":\"Col\",\"gap\":1,\"children\":[{\"type\":\"Title\",\"value\":{{ (name) | tojson }},\"size\":\"sm\"},{\"type\":\"Text\",\"value\":{{ (bio) | tojson }},\"size\":\"sm\"}]}]}",
"jsonSchema": {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1
},
"bio": {
"type": "string",
"minLength": 1
}
},
"required": [
"name",
"bio"
],
"additionalProperties": false
},
"outputJsonPreview": {
"type": "Basic",
"children": [
{
"type": "Col",
"gap": 1,
"children": [
{
"type": "Title",
"value": "Harry Potter",
"size": "sm"
},
{
"type": "Text",
"value": "The boy who lived",
"size": "sm"
}
]
}
]
},
"encodedWidget": "eyJpZCI6IndpZ194djV6dGlxayIsIm5hbWUiOiJBdXRob3IgcHJldmlldyIsInZpZXciOiI8QmFzaWM-XG4gIDxDb2wgZ2FwPXsxfT5cbiAgICA8VGl0bGUgdmFsdWU9e25hbWV9IHNpemU9XCJzbVwiIC8-XG4gICAgPFRleHQgdmFsdWU9e2Jpb30gc2l6ZT1cInNtXCIgLz5cbiAgPC9Db2w-XG48L0Jhc2ljPiIsImRlZmF1bHRTdGF0ZSI6eyJuYW1lIjoiSGFycnkgUG90dGVyIiwiYmlvIjoiVGhlIGJveSB3aG8gbGl2ZWQifSwic2NoZW1hTW9kZSI6InpvZCIsImpzb25TY2hlbWEiOnsidHlwZSI6Im9iamVjdCIsInByb3BlcnRpZXMiOnsidGl0bGUiOnsidHlwZSI6InN0cmluZyJ9fSwicmVxdWlyZWQiOlsidGl0bGUiXSwiYWRkaXRpb25hbFByb3BlcnRpZXMiOmZhbHNlfSwic2NoZW1hIjoiaW1wb3J0IHsgeiB9IGZyb20gXCJ6b2RcIlxuXG5jb25zdCBQcm9maWxlID0gei5vYmplY3Qoe1xuICBuYW1lOiB6LnN0cmluZygpLnRyaW0oKS5taW4oMSksXG4gIGJpbzogei5zdHJpbmcoKS50cmltKCkubWluKDEpLFxufSk7XG5cbmV4cG9ydCBkZWZhdWx0IFByb2ZpbGU7Iiwic3RhdGVzIjpbXSwic2NoZW1hVmFsaWRpdHkiOiJ2YWxpZCIsInZpZXdWYWxpZGl0eSI6InZhbGlkIiwiZGVmYXVsdFN0YXRlVmFsaWRpdHkiOiJ2YWxpZCJ9"
}
18 changes: 18 additions & 0 deletions tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chatkit.server import diff_widget
from chatkit.types import WidgetItem
from chatkit.widgets import (
BasicRoot,
Card,
DynamicWidgetComponent,
DynamicWidgetRoot,
Expand Down Expand Up @@ -241,3 +242,20 @@ def test_widget_template_from_file(

assert isinstance(widget, DynamicWidgetRoot)
assert widget.model_dump(exclude_none=True) == expected_widget_dict


def test_widget_template_with_basic_root():
template = WidgetTemplate.from_file("assets/widgets/basic_root.widget")

with open("tests/assets/widgets/basic_root.json", "r") as file:
expected_widget_dict = json.load(file)

widget = template.build_basic(
{
"name": "Harry Potter",
"bio": "The boy who lived",
},
)

assert isinstance(widget, BasicRoot)
assert widget.model_dump(exclude_none=True) == expected_widget_dict