Skip to content

Commit

Permalink
openai[patch]: expose model request payload (#23287)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Jul 2, 2024
1 parent ed200bf commit cb98125
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,10 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
with self.client.create(messages=message_dicts, **params) as response:
with self.client.create(**payload) as response:
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
Expand Down Expand Up @@ -544,19 +543,25 @@ def _generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = self.client.create(**payload)
return self._create_chat_result(response)

def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params
def _get_request_payload(
    self,
    input_: LanguageModelInput,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> dict:
    """Build the request payload dict sent to the OpenAI chat completions client.

    Normalizes arbitrary ``LanguageModelInput`` into a message list and merges
    the model's default parameters with per-call overrides.

    Args:
        input_: Prompt input; coerced to ``BaseMessage`` objects via
            ``self._convert_input(...).to_messages()``.
        stop: Optional stop sequences; when given, forwarded as the ``stop``
            request parameter.
        **kwargs: Per-call API parameters (e.g. ``stream``); these take
            precedence over ``self._default_params`` in the merge below.

    Returns:
        A dict of keyword arguments for ``client.create(**payload)``, containing
        the serialized ``messages`` plus all merged parameters.
    """
    messages = self._convert_input(input_).to_messages()
    if stop is not None:
        # Fold stop into kwargs so it participates in the same merge/override
        # ordering as every other per-call parameter.
        kwargs["stop"] = stop
    return {
        "messages": [_convert_message_to_dict(m) for m in messages],
        # kwargs last: explicit per-call values override model defaults.
        **self._default_params,
        **kwargs,
    }

def _create_chat_result(
self, response: Union[dict, openai.BaseModel]
Expand Down Expand Up @@ -600,11 +605,10 @@ async def _astream(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
response = await self.async_client.create(messages=message_dicts, **params)
response = await self.async_client.create(**payload)
async with response:
async for chunk in response:
if not isinstance(chunk, dict):
Expand Down Expand Up @@ -666,10 +670,8 @@ async def _agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)

message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await self.async_client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = await self.async_client.create(**payload)
return self._create_chat_result(response)

@property
Expand Down

0 comments on commit cb98125

Please sign in to comment.