partners: AI21 Labs Jamba Streaming Support (#23538)
Thank you for contributing to LangChain!

- [x] **PR title**: "partners: AI21 Labs Jamba Streaming Support"

- [x] **PR message**:
    - **Description:** Added support for streaming in the AI21 Jamba model.
    - **Twitter handle:** https://github.com/AI21Labs

- [x] **Add tests and docs**: README documentation for the new streaming
support is included below.

- [x] **Lint and test**: Ran `make format`, `make lint` and `make test`
from the root of the modified package(s). See the contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

---------

Co-authored-by: Asaf Gardin <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
Co-authored-by: Chester Curme <[email protected]>
4 people committed Jul 2, 2024
1 parent 5cd4083 commit 320dc31
Showing 7 changed files with 253 additions and 36 deletions.
29 changes: 28 additions & 1 deletion libs/partners/ai21/README.md
@@ -28,13 +28,40 @@ Then initialize
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct-preview")
chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
chat.invoke(messages)
```

For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat).

### Streaming in Chat
Streaming is supported by the latest models. To use streaming, set the `streaming` parameter to `True` when initializing the model.

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct", streaming=True)
messages = [HumanMessage(content="Hello from AI21")]

response = chat.invoke(messages)
```

Even with `streaming=True`, `invoke` returns a single complete message once the stream has been consumed. To process chunks as they arrive, use the `stream` method directly:

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]

for chunk in chat.stream(messages):
    print(chunk)
```
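
Only synchronous streaming is added in this PR. Async iteration should still work through LangChain core's default `astream` fallback, which drives the synchronous `_stream` from a worker thread. A minimal sketch under that assumption (not code from this PR):

```python
import asyncio

from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21


async def main() -> None:
    chat = ChatAI21(model="jamba-instruct")
    messages = [HumanMessage(content="Hello from AI21")]

    # astream is expected to fall back to the sync _stream added in this PR.
    async for chunk in chat.astream(messages):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```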


## LLMs
You can use AI21's Jurassic generative AI models as LangChain LLMs.
To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which
127 changes: 120 additions & 7 deletions libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
@@ -1,12 +1,21 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload

from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
-from ai21.models.chat import ChatMessage
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from ai21.models.chat import ChatCompletionChunk, ChatMessage
+from ai21.stream.stream import Stream as AI21Stream
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+)
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGenerationChunk

_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
@@ -63,8 +72,31 @@ def _chat_message(
    ) -> _ChatMessageTypes:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        pass

    @abstractmethod
-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        pass

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
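
The `@overload` declarations pair `Literal[True]` and `Literal[False]` values of `stream` with distinct return types, so a type checker can infer that streaming calls yield chunks while non-streaming calls return a list of messages. A self-contained sketch of the pattern (the `fetch` function is hypothetical, not part of this PR):

```python
from typing import Iterator, List, Literal, Union, overload


@overload
def fetch(stream: Literal[True]) -> Iterator[str]: ...
@overload
def fetch(stream: Literal[False]) -> List[str]: ...


def fetch(stream: bool) -> Union[Iterator[str], List[str]]:
    # The overloads above exist only for the type checker; this single
    # runtime implementation handles both cases.
    if stream:
        return iter(["chunk-1", "chunk-2"])
    return ["message-1"]


chunks = fetch(stream=True)   # inferred as Iterator[str]
batch = fetch(stream=False)   # inferred as List[str]
```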
@@ -102,7 +134,33 @@ def _chat_message(
    ) -> J2ChatMessage:
        return J2ChatMessage(role=RoleType(role), text=content)

-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        ...

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        ...

    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        if stream:
            raise NotImplementedError("Streaming is not supported for Jurassic models.")

        response = client.chat.create(**params)

        return [AIMessage(output.text) for output in response.outputs]
@@ -128,7 +186,62 @@ def _chat_message(
            content=content,
        )

-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
-        response = client.chat.completions.create(**params)
    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        ...

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        ...

    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        response = client.chat.completions.create(stream=stream, **params)

        if stream:
            return self._stream_response(response)

        return [AIMessage(choice.message.content) for choice in response.choices]

    def _stream_response(
        self,
        response: AI21Stream[ChatCompletionChunk],
    ) -> Iterator[ChatGenerationChunk]:
        for chunk in response:
            converted_message = self._convert_ai21_chunk_to_chunk(chunk)
            yield ChatGenerationChunk(message=converted_message)

    def _convert_ai21_chunk_to_chunk(
        self,
        chunk: ChatCompletionChunk,
    ) -> BaseMessageChunk:
        usage = chunk.usage
        content = chunk.choices[0].delta.content or ""

        if usage is None:
            return AIMessageChunk(
                content=content,
            )

        return AIMessageChunk(
            content=content,
            usage_metadata=UsageMetadata(
                input_tokens=usage.prompt_tokens,
                output_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
            ),
        )
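
The conversion above attaches `UsageMetadata` only when the AI21 chunk reports token usage, which typically arrives on the final chunk of a stream. A standalone sketch of that logic, using a hypothetical stand-in for the SDK's usage object so it runs without the `ai21` package:

```python
from dataclasses import dataclass
from typing import Optional

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata


@dataclass
class FakeUsage:
    """Hypothetical stand-in for the AI21 SDK's usage object."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def convert(content: str, usage: Optional[FakeUsage]) -> AIMessageChunk:
    # Mirrors _convert_ai21_chunk_to_chunk: usage metadata only when present.
    if usage is None:
        return AIMessageChunk(content=content)
    return AIMessageChunk(
        content=content,
        usage_metadata=UsageMetadata(
            input_tokens=usage.prompt_tokens,
            output_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
        ),
    )


print(convert("Hello", None))
print(convert("", FakeUsage(prompt_tokens=5, completion_tokens=7, total_tokens=12)))
```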
64 changes: 60 additions & 4 deletions libs/partners/ai21/langchain_ai21/chat_models.py
@@ -1,16 +1,20 @@
import asyncio
from functools import partial
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
-from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    LangSmithParams,
+    generate_from_stream,
+)
from langchain_core.messages import (
    BaseMessage,
)
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator

from langchain_ai21.ai21_base import AI21Base
@@ -75,6 +79,7 @@ class ChatAI21(BaseChatModel, AI21Base):

    n: int = 1
    """Number of chat completions to generate for each prompt."""
    streaming: bool = False

    _chat_adapter: ChatAdapter

@@ -166,14 +171,65 @@ def _generate(
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
-        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
        should_stream = stream or self.streaming

        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )

        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=should_stream,
            **kwargs,
        )

        messages = self._chat_adapter.call(self.client, **params)
        generations = [ChatGeneration(message=message) for message in messages]

        return ChatResult(generations=generations)

    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )

        for chunk in self._chat_adapter.call(self.client, **params):
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
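
Taken together, `streaming=True` routes `invoke` through `_stream` and `generate_from_stream`, and each chunk's text is forwarded to `on_llm_new_token`. A usage sketch with a custom callback handler (`PrintTokens` is illustrative, not part of this PR):

```python
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import HumanMessage

from langchain_ai21.chat_models import ChatAI21


class PrintTokens(BaseCallbackHandler):
    """Print each token as _stream reports it via on_llm_new_token."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end="", flush=True)


chat = ChatAI21(model="jamba-instruct", streaming=True)
messages = [HumanMessage(content="Hello from AI21")]

# invoke still returns one aggregated AIMessage; tokens arrive via the callback.
response = chat.invoke(messages, config={"callbacks": [PrintTokens()]})
```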
31 changes: 17 additions & 14 deletions libs/partners/ai21/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion libs/partners/ai21/pyproject.toml
@@ -14,7 +14,7 @@ license = "MIT"
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.4.1"
ai21 = "^2.7.0"

[tool.poetry.group.test]
optional = true