Add AZURE support #186

Status: Open. Wants to merge 3 commits into base: main.
37 changes: 34 additions & 3 deletions README.md
@@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.

## 🛠 Getting Started

The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key or an Azure OpenAI key with access to the GPT-4 Vision API.

Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):

For the OpenAI version:

```bash
cd backend
@@ -38,6 +40,21 @@ poetry shell
poetry run uvicorn main:app --reload --port 7001
```

For the Azure version, you need to set some additional environment variables (the vision and DALL·E 3 deployments must be in the same Azure resource):

```bash
cd backend
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
poetry install
poetry shell
poetry run uvicorn main:app --reload --port 7001
```

Run the frontend:

```bash
@@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001

## Configuration

* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
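
  For example, a minimal sketch assuming a hypothetical proxy URL (replace it with your own):

  ```bash
  # Hypothetical proxy endpoint -- adjust to your own proxy's URL
  echo "OPENAI_BASE_URL=https://my-openai-proxy.example.com/v1" >> backend/.env
  ```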

## Docker

If you have Docker installed on your system, in the root directory, run:

For the OpenAI version:

```bash
echo "OPENAI_API_KEY=sk-your-key" > .env
docker-compose up -d --build
```

For the Azure version:

```bash
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
docker-compose up -d --build
```

The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.

## 🙋‍♂️ FAQs
23 changes: 23 additions & 0 deletions backend/api_types.py
@@ -0,0 +1,23 @@
from pydantic import BaseModel
from typing import Union, Literal, Optional


class ApiProviderInfoBase(BaseModel):
name: Literal["openai", "azure"]


class OpenAiProviderInfo(ApiProviderInfoBase):
name: Literal["openai"] = "openai" # type: ignore
api_key: str
base_url: Optional[str] = None


class AzureProviderInfo(ApiProviderInfoBase):
name: Literal["azure"] = "azure" # type: ignore
api_version: str
api_key: str
deployment_name: str
resource_name: str


ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
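
A minimal usage sketch (not part of the diff) showing how these provider models might be constructed and discriminated on the `name` field; all keys, versions, and names below are placeholders:

```python
# Illustrative sketch only -- credentials, versions, and names are placeholders.
from api_types import ApiProviderInfo, AzureProviderInfo, OpenAiProviderInfo


def describe(provider: ApiProviderInfo) -> str:
    # The Literal `name` field acts as the discriminator between the two variants.
    if provider.name == "openai":
        return f"OpenAI (base_url={provider.base_url})"
    return f"Azure deployment '{provider.deployment_name}' in resource '{provider.resource_name}'"


print(describe(OpenAiProviderInfo(api_key="sk-placeholder")))
print(
    describe(
        AzureProviderInfo(
            api_key="placeholder-key",
            api_version="2023-12-01-preview",
            deployment_name="gpt-4-vision",
            resource_name="my-resource",
        )
    )
)
```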
32 changes: 23 additions & 9 deletions backend/eval.py
@@ -8,7 +8,7 @@
load_dotenv()

import os
from llm import stream_openai_response
from llm import stream_openai_response, stream_azure_openai_response
from prompts import assemble_prompt
import asyncio

@@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")

pprint_prompt(prompt_messages)

async def process_chunk(content: str):
pass

if not openai_api_key and not azure_openai_api_key:
raise Exception("OpenAI API or Azure key not found")

if not openai_api_key:

[Review thread on the line above]

Reviewer: Shouldn't you swap the condition with azure here?

Author: I check if at least one of the two is present; if both are missing, it raises an exception (but maybe it's better to have 2 separate checks).

Owner: I'm going to rework this a bit today before merging.

Author: Thank you!

Owner: Thank you for this PR! Excited to get it in.

raise Exception("OpenAI API key not found")

completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
if not azure_openai_api_key:
completion = await stream_azure_openai_response(
prompt_messages,
azure_openai_api_key=azure_openai_api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_deployment_name=azure_openai_deployment_name,
callback=lambda x: process_chunk(x),
)

return completion

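A minimal sketch (not part of the diff) of driving `generate_code_core` once the environment variables above are set; the image URL and the stack value are assumptions:

```python
# Illustrative only -- assumes the OpenAI or Azure env vars from the README are set,
# and that "html_tailwind" is an accepted stack value in this repo.
import asyncio

from eval import generate_code_core


async def main() -> None:
    html = await generate_code_core(
        image_url="https://example.com/screenshot.png",  # placeholder screenshot URL
        stack="html_tailwind",
    )
    print(html[:200])


asyncio.run(main())
```
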
71 changes: 66 additions & 5 deletions backend/image_generation.py
@@ -1,12 +1,32 @@
import asyncio
import re
from typing import Dict, List, Union
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
from bs4 import BeautifulSoup


async def process_tasks(prompts: List[str], api_key: str, base_url: str):
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
async def process_tasks(
prompts: List[str],
api_key: str | None,
base_url: str | None,
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
if api_key is not None:
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
if azure_openai_api_key is not None:
tasks = [
generate_image_azure(
prompt,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
for prompt in prompts
]
results = await asyncio.gather(*tasks, return_exceptions=True)

processed_results: List[Union[str, None]] = []
@@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
return res.data[0].url


async def generate_image_azure(
prompt: str,
azure_openai_api_key: str,
azure_openai_api_version: str,
azure_openai_resource_name: str,
azure_openai_dalle3_deployment_name: str,
):
client = AsyncAzureOpenAI(
api_version=azure_openai_api_version,
api_key=azure_openai_api_key,
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
azure_deployment=azure_openai_dalle3_deployment_name,
)
image_params: Dict[str, Union[str, int]] = {
"model": "dall-e-3",
"quality": "standard",
"style": "natural",
"n": 1,
"size": "1024x1024",
"prompt": prompt,
}
res = await client.images.generate(**image_params)
await client.close()
return res.data[0].url


def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url)
@@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:


async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
code: str,
api_key: str | None,
base_url: Union[str, None] | None,
image_cache: Dict[str, str],
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
# Find all images
soup = BeautifulSoup(code, "html.parser")
@@ -90,7 +143,15 @@ async def generate_images(
return code

# Generate images
results = await process_tasks(prompts, api_key, base_url)
results = await process_tasks(
prompts,
api_key,
base_url,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)

# Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results))
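A minimal sketch (not part of the diff) of calling the extended `generate_images` against an Azure DALL·E 3 deployment; every credential, version, and name below is a placeholder:

```python
# Illustrative only -- requires real Azure OpenAI credentials to actually generate images.
import asyncio

from image_generation import generate_images

HTML = '<img alt="a red bicycle leaning against a wall" src="https://placehold.co/300x200">'


async def main() -> None:
    updated_html = await generate_images(
        HTML,
        api_key=None,  # skip the plain OpenAI path
        base_url=None,
        image_cache={},
        azure_openai_api_key="placeholder-key",
        azure_openai_dalle3_api_version="2023-12-01-preview",  # placeholder version
        azure_openai_resource_name="my-resource",
        azure_openai_dalle3_deployment_name="dalle3",
    )
    print(updated_html)


asyncio.run(main())
```
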
21 changes: 17 additions & 4 deletions backend/llm.py
@@ -1,17 +1,30 @@
from typing import Awaitable, Callable, List
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk

from api_types import ApiProviderInfo

MODEL_GPT_4_VISION = "gpt-4-vision-preview"


async def stream_openai_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
base_url: str | None,
api_provider_info: ApiProviderInfo,
callback: Callable[[str], Awaitable[None]],
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
if api_provider_info.name == "openai":
client = AsyncOpenAI(
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
)
elif api_provider_info.name == "azure":
client = AsyncAzureOpenAI(
api_version=api_provider_info.api_version,
api_key=api_provider_info.api_key,
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
azure_deployment=api_provider_info.deployment_name,
)
else:
raise Exception("Invalid api_provider_info")

model = MODEL_GPT_4_VISION

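A minimal sketch (not part of the diff) of calling the refactored `stream_openai_response` with an Azure provider; the deployment, resource, and API version are placeholders:

```python
# Illustrative only -- needs a real GPT-4 Vision deployment on Azure to execute.
import asyncio

from api_types import AzureProviderInfo
from llm import stream_openai_response


async def main() -> None:
    async def on_chunk(chunk: str) -> None:
        print(chunk, end="", flush=True)

    provider = AzureProviderInfo(
        api_key="placeholder-key",
        api_version="2023-12-01-preview",  # placeholder API version
        deployment_name="gpt-4-vision",    # placeholder deployment name
        resource_name="my-resource",       # placeholder resource name
    )
    completion = await stream_openai_response(
        [{"role": "user", "content": "Say hello."}],
        api_provider_info=provider,
        callback=on_chunk,
    )
    print("\n--- done, {} characters ---".format(len(completion)))


asyncio.run(main())
```
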
67 changes: 62 additions & 5 deletions backend/routes/generate_code.py
@@ -2,6 +2,7 @@
import traceback
from fastapi import APIRouter, WebSocket
import openai
from api_types import AzureProviderInfo, OpenAiProviderInfo
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam
@@ -64,6 +65,12 @@ async def throw_error(
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error.
openai_api_key = None
azure_openai_api_key = None
azure_openai_resource_name = None
azure_openai_deployment_name = None
azure_openai_api_version = None
azure_openai_dalle3_deployment_name = None
azure_openai_dalle3_api_version = None
if "accessCode" in params and params["accessCode"]:
print("Access code - using platform API key")
res = await validate_access_token(params["accessCode"])
@@ -83,15 +90,29 @@
print("Using OpenAI API key from client-side settings dialog")
else:
openai_api_key = os.environ.get("OPENAI_API_KEY")
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME"
)
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
azure_openai_dalle3_deployment_name = os.environ.get(
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
)
azure_openai_dalle3_api_version = os.environ.get(
"AZURE_OPENAI_DALLE3_API_VERSION"
)
if openai_api_key:
print("Using OpenAI API key from environment variable")
if azure_openai_api_key:
print("Using Azure OpenAI API key from environment variable")

if not openai_api_key:
print("OpenAI API key not found")
if not openai_api_key and not azure_openai_api_key:
print("OpenAI API or Azure key not found")
await websocket.send_json(
{
"type": "error",
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
}
)
return
@@ -190,12 +211,44 @@ async def process_chunk(content: str):
completion = await mock_completion(process_chunk)
else:
try:
api_provider_info = None
if openai_api_key is not None:
api_provider_info = {
"name": "openai",
"api_key": openai_api_key,
"base_url": openai_base_url,
}

api_provider_info = OpenAiProviderInfo(
api_key=openai_api_key, base_url=openai_base_url
)

if azure_openai_api_key is not None:
if (
not azure_openai_api_version
or not azure_openai_resource_name
or not azure_openai_deployment_name
):
raise Exception(
"Missing Azure OpenAI API version, resource name, or deployment name"
)

api_provider_info = AzureProviderInfo(
api_key=azure_openai_api_key,
api_version=azure_openai_api_version,
deployment_name=azure_openai_deployment_name,
resource_name=azure_openai_resource_name,
)

if api_provider_info is None:
raise Exception("Invalid api_provider_info")

completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
api_provider_info=api_provider_info,
callback=lambda x: process_chunk(x),
)

except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@@ -244,6 +297,10 @@ async def process_chunk(content: str):
api_key=openai_api_key,
base_url=openai_base_url,
image_cache=image_cache,
azure_openai_api_key=azure_openai_api_key,
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
)
else:
updated_html = completion