diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs index 9ba36410a30f..9865f8ea8100 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs @@ -8,7 +8,6 @@ using Microsoft.Extensions.Logging; namespace Microsoft.AutoGen.Agents; - public sealed class GrpcGateway : BackgroundService, IGateway { private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(30); @@ -16,13 +15,12 @@ public sealed class GrpcGateway : BackgroundService, IGateway private readonly IClusterClient _clusterClient; private readonly ConcurrentDictionary _agentState = new(); private readonly IRegistryGrain _gatewayRegistry; - private readonly ISubscriptionsGrain _subscriptions; + private readonly ISubscriptionsGrain _subscriptionsGrain; private readonly IGateway _reference; // The agents supported by each worker process. + private SubscriptionsState _subscriptionsState = new(); private readonly ConcurrentDictionary> _supportedAgentTypes = []; public readonly ConcurrentDictionary _workers = new(); - private readonly ConcurrentDictionary _subscriptionsByAgentType = new(); - private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); // The mapping from agent id to worker process. private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); @@ -36,7 +34,7 @@ public GrpcGateway(IClusterClient clusterClient, ILogger logger) _clusterClient = clusterClient; _reference = clusterClient.CreateObjectReference(this); _gatewayRegistry = clusterClient.GetGrain(0); - _subscriptions = clusterClient.GetGrain(0); + _subscriptionsGrain = clusterClient.GetGrain(0); } public async ValueTask BroadcastEvent(CloudEvent evt) { @@ -135,10 +133,9 @@ private async ValueTask AddSubscriptionAsync(GrpcWorkerConnection connection, Ad topic = request.Subscription.TypeSubscription.TopicType; agentType = request.Subscription.TypeSubscription.AgentType; } - _subscriptionsByAgentType[agentType] = request.Subscription; - _subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType); - await _subscriptions.SubscribeAsync(topic, agentType); - //var response = new AddSubscriptionResponse { RequestId = request.RequestId, Error = "", Success = true }; + await _subscriptionsGrain.SubscribeAsync(topic, agentType).ConfigureAwait(true); + _subscriptionsState = await _subscriptionsGrain.GetSubscriptionsStateAsync().ConfigureAwait(true); + Message response = new() { AddSubscriptionResponse = new() @@ -169,12 +166,13 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, } private async ValueTask DispatchEventAsync(CloudEvent evt) { + var _subscriptionsByTopic = await _subscriptionsGrain.GetSubscriptionsByTopicAsync().ConfigureAwait(true); // get the event type and then send to all agents that are subscribed to that event type var eventType = evt.Type; // ensure that we get agentTypes as an async enumerable list - try to get the value of agentTypes by topic and then cast it to an async enumerable list if (_subscriptionsByTopic.TryGetValue(eventType, out var agentTypes)) { - await DispatchEventToAgentsAsync(agentTypes, evt); + await DispatchEventToAgentsAsync(agentTypes, evt).ConfigureAwait(false); } // instead of an exact match, we can also check for a prefix match where key starts with the eventType else if (_subscriptionsByTopic.Keys.Any(key => key.StartsWith(eventType))) diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcWorkerConnection.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcWorkerConnection.cs index f2eb81c43602..fdb9a5c0d872 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcWorkerConnection.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcWorkerConnection.cs @@ -7,7 +7,7 @@ namespace Microsoft.AutoGen.Agents; -internal sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection +public sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection { private static long s_nextConnectionId; private readonly Task _readTask; diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs index 9905f6aebac6..4ee3d4abcae4 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs @@ -17,7 +17,7 @@ public async ValueTask WriteStateAsync(AgentState newState, string eTag, if ((string.IsNullOrEmpty(state.Etag)) || (string.IsNullOrEmpty(eTag)) || (string.Equals(state.Etag, eTag, StringComparison.Ordinal))) { state.State = newState; - await state.WriteStateAsync().ConfigureAwait(false); + await state.WriteStateAsync().ConfigureAwait(true); } else { diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/ISubscriptionsGrain.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/ISubscriptionsGrain.cs index d3af459bb7ff..566ba8d84f4f 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/ISubscriptionsGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/ISubscriptionsGrain.cs @@ -1,10 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ISubscriptionsGrain.cs +using System.Collections.Concurrent; +using Microsoft.AutoGen.Abstractions; + namespace Microsoft.AutoGen.Agents; + +[Alias("Microsoft.AutoGen.Agents.ISubscriptionsGrain")] public interface ISubscriptionsGrain : IGrainWithIntegerKey { + [Alias("SubscribeAsync")] ValueTask SubscribeAsync(string agentType, string topic); + [Alias("UnsubscribeAsync")] ValueTask UnsubscribeAsync(string agentType, string topic); - ValueTask>> GetSubscriptions(string agentType); + [Alias("GetSubscriptionsAsync")] + ValueTask>> GetSubscriptionsByAgentTypeAsync(string? agentType = null); + [Alias ("GetSubscriptionsByTopicAsync")] + ValueTask>> GetSubscriptionsByTopicAsync(string? topic = null); + [Alias("GetSubscriptionsByAgentTypeAsync")] + ValueTask GetSubscriptionsStateAsync(); + [Alias("WriteSubscriptionsStateAsync")] + ValueTask WriteSubscriptionsStateAsync(SubscriptionsState subscriptionsState); } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsGrain.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsGrain.cs index 0e647dbab980..e8a6dc175532 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsGrain.cs @@ -1,50 +1,74 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SubscriptionsGrain.cs +using System.Collections.Concurrent; +using Microsoft.AutoGen.Abstractions; namespace Microsoft.AutoGen.Agents; -internal sealed class SubscriptionsGrain([PersistentState("state", "PubSubStore")] IPersistentState state) : Grain, ISubscriptionsGrain +internal sealed class SubscriptionsGrain([PersistentState("state", "PubSubStore")] IPersistentState subscriptionsState) : Grain, ISubscriptionsGrain { - private readonly Dictionary> _subscriptions = new(); - public ValueTask>> GetSubscriptions(string? agentType = null) + private readonly IPersistentState _subscriptionsState = subscriptionsState; + + public ValueTask>> GetSubscriptionsByAgentTypeAsync(string? agentType = null) { + var _subscriptions = _subscriptionsState.State.SubscriptionsByAgentType; //if agentType is null, return all subscriptions else filter on agentType if (agentType != null) { - return new ValueTask>>(_subscriptions.Where(x => x.Value.Contains(agentType)).ToDictionary(x => x.Key, x => x.Value)); + var filteredSubscriptions = _subscriptions.Where(x => x.Value.Contains(agentType)); + return ValueTask.FromResult>>((ConcurrentDictionary>)filteredSubscriptions); } - return new ValueTask>>(_subscriptions); + return ValueTask.FromResult>>(_subscriptions); } - public async ValueTask SubscribeAsync(string agentType, string topic) + public ValueTask>> GetSubscriptionsByTopicAsync(string? topic = null) { - if (!_subscriptions.TryGetValue(topic, out var subscriptions)) - { - subscriptions = _subscriptions[topic] = []; - } - if (!subscriptions.Contains(agentType)) + var _subscriptions = _subscriptionsState.State.SubscriptionsByTopic; + //if topic is null, return all subscriptions else filter on topic + if (topic != null) { - subscriptions.Add(agentType); + var filteredSubscriptions = _subscriptions.Where(x => x.Key == topic); + return ValueTask.FromResult>>((ConcurrentDictionary>)filteredSubscriptions); } - _subscriptions[topic] = subscriptions; - state.State.Subscriptions = _subscriptions; - await state.WriteStateAsync().ConfigureAwait(false); + return ValueTask.FromResult>>(_subscriptions); + } + public ValueTask GetSubscriptionsStateAsync() => ValueTask.FromResult(_subscriptionsState.State); + + public async ValueTask SubscribeAsync(string agentType, string topic) + { + await WriteSubscriptionsAsync(agentType: agentType, topic: topic, subscribe: true).ConfigureAwait(false); } public async ValueTask UnsubscribeAsync(string agentType, string topic) { - if (!_subscriptions.TryGetValue(topic, out var subscriptions)) + await WriteSubscriptionsAsync(agentType: agentType, topic: topic, subscribe: false).ConfigureAwait(false); + } + public async ValueTask WriteSubscriptionsStateAsync(SubscriptionsState subscriptionsState) + { + _subscriptionsState.State = subscriptionsState; + await _subscriptionsState.WriteStateAsync().ConfigureAwait(true); + } + + private async ValueTask WriteSubscriptionsAsync(string agentType, string topic, bool subscribe=true) + { + var _subscriptions = await GetSubscriptionsByAgentTypeAsync().ConfigureAwait(true); + if (!_subscriptions.TryGetValue(topic, out var agentTypes)) { - subscriptions = _subscriptions[topic] = []; + agentTypes = _subscriptions[topic] = []; } - if (!subscriptions.Contains(agentType)) + if (!agentTypes.Contains(agentType)) { - subscriptions.Remove(agentType); + if (subscribe) + { + agentTypes.Add(agentType); + } + else + { + agentTypes.Remove(agentType); + } } - _subscriptions[topic] = subscriptions; - state.State.Subscriptions = _subscriptions; - await state.WriteStateAsync(); + _subscriptionsState.State.SubscriptionsByAgentType = _subscriptions; + var _subsByTopic = await GetSubscriptionsByTopicAsync().ConfigureAwait(true); + _subsByTopic.GetOrAdd(topic, _ => []).Add(agentType); + _subscriptionsState.State.SubscriptionsByTopic = _subsByTopic; + await _subscriptionsState.WriteStateAsync().ConfigureAwait(false); } -} -public sealed class SubscriptionsState -{ - public Dictionary> Subscriptions { get; set; } = new(); -} +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsState.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsState.cs new file mode 100644 index 000000000000..eb1bcddc38ae --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/SubscriptionsState.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SubscriptionsState.cs +using System.Collections.Concurrent; + +namespace Microsoft.AutoGen.Abstractions; +[GenerateSerializer] +[Serializable] +public sealed class SubscriptionsState +{ + public ConcurrentDictionary> SubscriptionsByTopic = new(); + public ConcurrentDictionary> SubscriptionsByAgentType { get; set; } = new(); +} \ No newline at end of file diff --git a/dotnet/test/Microsoft.AutoGen.Agents.Tests/ISubscriptionsGrainTests.cs b/dotnet/test/Microsoft.AutoGen.Agents.Tests/ISubscriptionsGrainTests.cs new file mode 100644 index 000000000000..02dcb75571c4 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Agents.Tests/ISubscriptionsGrainTests.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ISubscriptionsGrainTests.cs + +using System.Collections.Concurrent; +using Microsoft.AutoGen.Abstractions; +using Moq; +using Xunit; + +namespace Microsoft.AutoGen.Agents.Tests; + +public class ISubscriptionsGrainTests +{ + private readonly Mock _mockSubscriptionsGrain; + + public ISubscriptionsGrainTests() + { + _mockSubscriptionsGrain = new Mock(); + } + + [Fact] + public async Task GetSubscriptionsStateAsync_ReturnsCorrectState() + { + // Arrange + var subscriptionsState = new SubscriptionsState + { + SubscriptionsByAgentType = new ConcurrentDictionary> + { + ["topic1"] = ["agentType1"], + ["topic2"] = ["agentType2"] + } + }; + _mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsStateAsync()).ReturnsAsync(subscriptionsState); + + // Act + var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsStateAsync(); + + // Assert + Assert.Equal(2, result.SubscriptionsByAgentType.Count); + Assert.Contains("topic1", result.SubscriptionsByAgentType.Keys); + Assert.Contains("topic2", result.SubscriptionsByAgentType.Keys); + } + + [Fact] + public async Task GetSubscriptions_ReturnsAllSubscriptions_WhenAgentTypeIsNull() + { + // Arrange + var subscriptions = new ConcurrentDictionary>(); + subscriptions.TryAdd("topic1", new List { "agentType1" }); + subscriptions.TryAdd("topic2", new List { "agentType2" }); + _mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsByAgentTypeAsync(null)).ReturnsAsync(subscriptions); + + // Act + var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsByAgentTypeAsync(); + + // Assert + Assert.Equal(2, result.Count); + Assert.Contains("topic1", result.Keys); + Assert.Contains("topic2", result.Keys); + } + + [Fact] + public async Task GetSubscriptions_ReturnsFilteredSubscriptions_WhenAgentTypeIsNotNull() + { + // Arrange + var subscriptions = new ConcurrentDictionary>(); + subscriptions.TryAdd("topic1", new List { "agentType1" }); + _mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsByAgentTypeAsync("agentType1")).ReturnsAsync(subscriptions); + + // Act + var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsByAgentTypeAsync("agentType1"); + + // Assert + Assert.Single(result); + Assert.Contains("topic1", result.Keys); + } + + [Fact] + public async Task SubscribeAsync_AddsSubscription() + { + // Arrange + _mockSubscriptionsGrain.Setup(grain => grain.SubscribeAsync("agentType1", "topic1")).Returns(ValueTask.CompletedTask); + + // Act + await _mockSubscriptionsGrain.Object.SubscribeAsync("agentType1", "topic1"); + + // Assert + _mockSubscriptionsGrain.Verify(grain => grain.SubscribeAsync("agentType1", "topic1"), Times.Once); + } + + [Fact] + public async Task UnsubscribeAsync_RemovesSubscription() + { + // Arrange + _mockSubscriptionsGrain.Setup(grain => grain.UnsubscribeAsync("agentType1", "topic1")).Returns(ValueTask.CompletedTask); + + // Act + await _mockSubscriptionsGrain.Object.UnsubscribeAsync("agentType1", "topic1"); + + // Assert + _mockSubscriptionsGrain.Verify(grain => grain.UnsubscribeAsync("agentType1", "topic1"), Times.Once); + } +} \ No newline at end of file