Skip to content

Commit

Permalink
Add support for agent service providers and discovery
Browse files Browse the repository at this point in the history
Introduce interfaces and implementations for agent service providers and service discovery. Updated `KernelFactory` to integrate the new service discovery mechanism and restructured plugin discoverer naming for consistency.
  • Loading branch information
sfmskywalker committed Aug 18, 2024
1 parent 7e4ddab commit 81cbe2f
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Elsa.Agents;

public interface IAgentServiceProvider
{
string Name { get; }
void ConfigureKernel(KernelBuilderContext context);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace Elsa.Agents;

public interface IPluginsDiscoverer
public interface IPluginDiscoverer
{
IEnumerable<PluginDescriptor> GetPluginDescriptors();
}
6 changes: 6 additions & 0 deletions src/modules/Elsa.Agents.Core/Contracts/IServiceDiscoverer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Elsa.Agents;

public interface IServiceDiscoverer
{
IEnumerable<IAgentServiceProvider> Discover();
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=entities/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=extensions/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=models/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=serviceproviders/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=services/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,9 @@ public static IServiceCollection AddPluginProvider<T>(this IServiceCollection se
{
return services.AddScoped<IPluginProvider, T>();
}

public static IServiceCollection AddAgentServiceProvider<T>(this IServiceCollection services) where T: class, IAgentServiceProvider
{
return services.AddScoped<IAgentServiceProvider, T>();
}
}
5 changes: 4 additions & 1 deletion src/modules/Elsa.Agents.Core/Features/AgentsFeature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ public override void Configure()
Services
.AddScoped<KernelFactory>()
.AddScoped<AgentInvoker>()
.AddScoped<IPluginsDiscoverer, PluginsDiscoverer>()
.AddScoped<IPluginDiscoverer, PluginDiscoverer>()
.AddScoped<IServiceDiscoverer, ServiceDiscoverer>()
.AddScoped(_kernelConfigProviderFactory)
.AddScoped<ConfigurationKernelConfigProvider>()
.AddPluginProvider<ImageGeneratorPluginProvider>()
.AddAgentServiceProvider<OpenAIChatCompletionProvider>()
.AddAgentServiceProvider<OpenAITextToImageProvider>()
;
}
}
20 changes: 20 additions & 0 deletions src/modules/Elsa.Agents.Core/Models/KernelBuilderContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using JetBrains.Annotations;
using Microsoft.SemanticKernel;

namespace Elsa.Agents;

[UsedImplicitly]
public record KernelBuilderContext(IKernelBuilder KernelBuilder, KernelConfig KernelConfig, ServiceConfig ServiceConfig)
{
public string GetApiKey()
{
var settings = ServiceConfig.Settings;
if (settings.TryGetValue("ApiKey", out var apiKey))
return (string)apiKey;

if (settings.TryGetValue("ApiKeyRef", out var apiKeyRef))
return KernelConfig.ApiKeys[(string)apiKeyRef].Value;

throw new KeyNotFoundException($"No api key found for service {ServiceConfig.Type}");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Microsoft.SemanticKernel;

namespace Elsa.Agents;

public class OpenAIChatCompletionProvider : IAgentServiceProvider
{
public string Name => "OpenAIChatCompletion";
public void ConfigureKernel(KernelBuilderContext context)
{
var modelId = (string)context.ServiceConfig.Settings["ModelId"];
var apiKey = context.GetApiKey();
context.KernelBuilder.AddOpenAIChatCompletion(modelId, apiKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Microsoft.SemanticKernel;
#pragma warning disable SKEXP0010

namespace Elsa.Agents;

public class OpenAITextToImageProvider : IAgentServiceProvider
{
public string Name => "OpenAITextToImage";

public void ConfigureKernel(KernelBuilderContext context)
{
var modelId = (string)context.ServiceConfig.Settings["ModelId"];
var apiKey = context.GetApiKey();
context.KernelBuilder.AddOpenAITextToImage(apiKey, modelId: modelId);
}
}
48 changes: 13 additions & 35 deletions src/modules/Elsa.Agents.Core/Services/KernelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace Elsa.Agents;

public class KernelFactory(IPluginsDiscoverer pluginsDiscoverer, ILoggerFactory loggerFactory, IServiceProvider serviceProvider, ILogger<KernelFactory> logger)
public class KernelFactory(IPluginDiscoverer pluginDiscoverer, IServiceDiscoverer serviceDiscoverer, ILoggerFactory loggerFactory, IServiceProvider serviceProvider, ILogger<KernelFactory> logger)
{
public Kernel CreateKernel(KernelConfig kernelConfig, string agentName)
{
Expand All @@ -29,6 +29,8 @@ public Kernel CreateKernel(KernelConfig kernelConfig, AgentConfig agentConfig)

private void ApplyAgentConfig(IKernelBuilder builder, KernelConfig kernelConfig, AgentConfig agentConfig)
{
var services = serviceDiscoverer.Discover().ToDictionary(x => x.Name);

foreach (string serviceName in agentConfig.Services)
{
if (!kernelConfig.Services.TryGetValue(serviceName, out var serviceConfig))
Expand All @@ -37,52 +39,28 @@ private void ApplyAgentConfig(IKernelBuilder builder, KernelConfig kernelConfig,
continue;
}

AddService(builder, kernelConfig, serviceConfig);
AddService(builder, kernelConfig, serviceConfig, services);
}

AddPlugins(builder, agentConfig);
AddAgents(builder, kernelConfig, agentConfig);
}

private void AddService(IKernelBuilder builder, KernelConfig kernelConfig, ServiceConfig serviceConfig)
private void AddService(IKernelBuilder builder, KernelConfig kernelConfig, ServiceConfig serviceConfig, Dictionary<string, IAgentServiceProvider> services)
{
switch (serviceConfig.Type)
if (!services.TryGetValue(serviceConfig.Type, out var serviceProvider))
{
case "OpenAIChatCompletion":
{
var modelId = (string)serviceConfig.Settings["ModelId"];
var apiKey = GetApiKey(kernelConfig, serviceConfig);
builder.AddOpenAIChatCompletion(modelId, apiKey);
break;
}
case "OpenAITextToImage":
{
var modelId = (string)serviceConfig.Settings["ModelId"];
var apiKey = GetApiKey(kernelConfig, serviceConfig);
builder.AddOpenAITextToImage(apiKey, modelId: modelId);
break;
}
default:
logger.LogWarning($"Unknown service type {serviceConfig.Type}");
break;
logger.LogWarning($"Service provider {serviceConfig.Type} not found");
return;
}

var context = new KernelBuilderContext(builder, kernelConfig, serviceConfig);
serviceProvider.ConfigureKernel(context);
}

private string GetApiKey(KernelConfig kernelConfig, ServiceConfig service)
{
var settings = service.Settings;
if (settings.TryGetValue("ApiKey", out var apiKey))
return (string)apiKey;

if (settings.TryGetValue("ApiKeyRef", out var apiKeyRef))
return kernelConfig.ApiKeys[(string)apiKeyRef].Value;

throw new KeyNotFoundException($"No api key found for service {service.Type}");
}


private void AddPlugins(IKernelBuilder builder, AgentConfig agent)
{
var plugins = pluginsDiscoverer.GetPluginDescriptors().ToDictionary(x => x.Name);
var plugins = pluginDiscoverer.GetPluginDescriptors().ToDictionary(x => x.Name);
foreach (var pluginName in agent.Plugins)
{
if (!plugins.TryGetValue(pluginName, out var pluginDescriptor))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace Elsa.Agents;

public class PluginsDiscoverer(IEnumerable<IPluginProvider> providers) : IPluginsDiscoverer
public class PluginDiscoverer(IEnumerable<IPluginProvider> providers) : IPluginDiscoverer
{
public IEnumerable<PluginDescriptor> GetPluginDescriptors()
{
Expand Down
9 changes: 9 additions & 0 deletions src/modules/Elsa.Agents.Core/Services/ServiceDiscoverer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Elsa.Agents;

public class ServiceDiscoverer(IEnumerable<IAgentServiceProvider> providers) : IServiceDiscoverer
{
public IEnumerable<IAgentServiceProvider> Discover()
{
return providers;
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
using System.Text.Json;
using Elsa.Extensions;

namespace Elsa.Agents.Persistence.EntityFrameworkCore;

internal static class JsonValueConverterHelper
{
private static readonly JsonSerializerOptions JsonSerializerOptions = CreateJsonSerializerOptions();

public static T Deserialize<T>(string json) where T : class
{
return string.IsNullOrWhiteSpace(json) ? null : JsonSerializer.Deserialize<T>(json);
return (string.IsNullOrWhiteSpace(json) ? null : JsonSerializer.Deserialize<T>(json, JsonSerializerOptions))!;
}

public static string Serialize<T>(T obj) where T : class
{
return obj == null ? null : JsonSerializer.Serialize(obj);
return (obj == null ? null : JsonSerializer.Serialize(obj, JsonSerializerOptions))!;
}

private static JsonSerializerOptions CreateJsonSerializerOptions()
{
return new JsonSerializerOptions().WithConverters(new PrimitiveDictionaryConverter());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Elsa.Agents.Persistence.EntityFrameworkCore;

public class PrimitiveDictionaryConverter : JsonConverter<IDictionary<string, object>>
{
public override IDictionary<string, object> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType != JsonTokenType.StartObject)
throw new JsonException("Expected start of object.");

var dictionary = new Dictionary<string, object>();

while (reader.Read())
{
if (reader.TokenType == JsonTokenType.EndObject)
return dictionary;

var key = reader.GetString()!;
reader.Read();
var value = ReadValue(ref reader, options);
dictionary.Add(key, value);
}

throw new JsonException("Expected end of object.");
}

private object ReadValue(ref Utf8JsonReader reader, JsonSerializerOptions options)
{
switch (reader.TokenType)
{
case JsonTokenType.String:
return reader.GetString()!;
case JsonTokenType.Number:
if (reader.TryGetInt64(out var l))
return l;
return reader.GetDouble();
case JsonTokenType.True:
return true;
case JsonTokenType.False:
return false;
case JsonTokenType.Null:
return null!;
case JsonTokenType.StartObject:
return JsonSerializer.Deserialize<Dictionary<string, object>>(ref reader, options)!;
case JsonTokenType.StartArray:
return JsonSerializer.Deserialize<List<object>>(ref reader, options)!;
default:
using (var document = JsonDocument.ParseValue(ref reader))
{
return document.RootElement.Clone().ToString();
}
}
}

public override void Write(Utf8JsonWriter writer, IDictionary<string, object> value, JsonSerializerOptions options)
{
writer.WriteStartObject();

foreach (var kvp in value)
{
writer.WritePropertyName(kvp.Key);
JsonSerializer.Serialize(writer, kvp.Value, options);
}

writer.WriteEndObject();
}
}

0 comments on commit 81cbe2f

Please sign in to comment.