diff --git a/pyproject.toml b/pyproject.toml
index dcae99d..a19c43a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,8 @@ dependencies = [
     "prompt-toolkit",
     "yaspin",
     "tomlkit",
-    "tomli >= 1.1.0; python_version < '3.11'"
+    "tomli >= 1.1.0; python_version < '3.11'",
+    "google-generativeai"
 ]
 
 [project.scripts]
diff --git a/src/shelloracle/providers/__init__.py b/src/shelloracle/providers/__init__.py
index 7734d4e..eceb6c0 100644
--- a/src/shelloracle/providers/__init__.py
+++ b/src/shelloracle/providers/__init__.py
@@ -79,8 +79,16 @@ def _providers() -> dict[str, type[Provider]]:
     from shelloracle.providers.ollama import Ollama
     from shelloracle.providers.openai import OpenAI
     from shelloracle.providers.xai import XAI
-
-    return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI, Deepseek.name: Deepseek}
+    from shelloracle.providers.google import Google
+
+    return {
+        Ollama.name: Ollama,
+        OpenAI.name: OpenAI,
+        LocalAI.name: LocalAI,
+        XAI.name: XAI,
+        Deepseek.name: Deepseek,
+        Google.name: Google,
+    }
 
 
 def get_provider(name: str) -> type[Provider]:
diff --git a/src/shelloracle/providers/google.py b/src/shelloracle/providers/google.py
new file mode 100644
index 0000000..af31a7b
--- /dev/null
+++ b/src/shelloracle/providers/google.py
@@ -0,0 +1,37 @@
+from collections.abc import AsyncIterator
+
+import google.generativeai as genai
+
+from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
+
+
+class Google(Provider):
+    name = "Google"
+
+    api_key = Setting(default="")
+    model = Setting(default="gemini-pro")  # Assuming a default model name
+
+    def __init__(self):
+        if not self.api_key:
+            msg = "No API key provided"
+            raise ProviderError(msg)
+        genai.configure(api_key=self.api_key)
+        self.model_instance = genai.GenerativeModel(self.model)
+
+
+    async def generate(self, prompt: str) -> AsyncIterator[str]:
+        try:
+            response = await self.model_instance.generate_content_async(
+                [
+                    {"role": "user", "parts": [system_prompt]},
+                    {"role": "model", "parts": ["Okay."]},  # Gemini requires a model response before user input
+                    {"role": "user", "parts": [prompt]},
+                ],
+                stream=True
+            )
+
+            async for chunk in response:
+                yield chunk.text
+        except Exception as e:
+            msg = f"Something went wrong while querying Google Gemini: {e}"
+            raise ProviderError(msg) from e
\ No newline at end of file
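For reviewers, a minimal smoke-test sketch of the new provider, built only from what the diff defines (`Google`, its streaming `generate()`, and `ProviderError`). It assumes the `api_key` Setting resolves to a valid Gemini key through shelloracle's configuration; the prompt string is just an illustrative example.

```python
import asyncio

from shelloracle.providers import ProviderError
from shelloracle.providers.google import Google


async def main() -> None:
    # Assumes api_key is populated via shelloracle's Setting machinery;
    # per the diff, Google.__init__ raises ProviderError when it is empty.
    try:
        provider = Google()
    except ProviderError as e:
        print(f"provider not configured: {e}")
        return

    # generate() yields the model's reply as a stream of text chunks.
    async for chunk in provider.generate("find files modified in the last 24 hours"):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```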