using Xunit;
using Moq;
using Microsoft.AspNetCore.Http;
using OpenHarbor.MCP.Gateway.AspNetCore.Middleware;
using OpenHarbor.MCP.Gateway.Core.Interfaces;
using OpenHarbor.MCP.Gateway.Core.Models;
using System.Text;
using System.Text.Json;
namespace OpenHarbor.MCP.Gateway.AspNetCore.Tests.Middleware;
/// <summary>
/// Unit tests for <see cref="GatewayMiddleware"/>, written TDD-style.
/// Covers HTTP request interception on the gateway endpoint and routing to the gateway router.
/// </summary>
public class GatewayMiddlewareTests
{
    // Endpoint the middleware is expected to intercept in these tests.
    private const string InvokePath = "/mcp/invoke";

    // Terminal "next" delegate used when a test does not care whether the pipeline continues.
    private static readonly RequestDelegate NoOpNext = _ => Task.CompletedTask;

    [Fact]
    public async Task InvokeAsync_WithGatewayRequest_RoutesToGateway()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        mockRouter.Setup(r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()))
            .ReturnsAsync(new GatewayResponse
            {
                Success = true,
                Result = new Dictionary<string, object> { ["data"] = "test" }
            });
        var middleware = CreateMiddleware(mockRouter);
        var context = CreateInvokeContext(
            JsonSerializer.Serialize(new { toolName = "test_tool", arguments = new { } }));

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        mockRouter.Verify(
            r => r.RouteAsync(
                It.Is<GatewayRequest>(req => req.ToolName == "test_tool"),
                It.IsAny<CancellationToken>()),
            Times.Once);
        Assert.Equal(200, context.Response.StatusCode);
    }

    [Fact]
    public async Task InvokeAsync_WithNonGatewayPath_CallsNext()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        var nextCalled = false;
        var middleware = CreateMiddleware(mockRouter, _ =>
        {
            nextCalled = true;
            return Task.CompletedTask;
        });
        var context = new DefaultHttpContext();
        context.Request.Method = "GET";
        context.Request.Path = "/api/other";

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        Assert.True(nextCalled);
        mockRouter.Verify(
            r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()),
            Times.Never);
    }

    [Fact]
    public async Task InvokeAsync_WithGatewayError_Returns500()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        mockRouter.Setup(r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()))
            .ReturnsAsync(new GatewayResponse { Success = false, Error = "Routing failed", ErrorCode = "ROUTE_ERROR" });
        var middleware = CreateMiddleware(mockRouter);
        var context = CreateInvokeContext(JsonSerializer.Serialize(new { toolName = "test_tool" }));

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        Assert.Equal(500, context.Response.StatusCode);
    }

    [Fact]
    public async Task InvokeAsync_WithInvalidJson_Returns400()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        var middleware = CreateMiddleware(mockRouter);
        var context = CreateInvokeContext("invalid json");

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        Assert.Equal(400, context.Response.StatusCode);
        // A body that cannot be parsed must never reach the router.
        mockRouter.Verify(
            r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()),
            Times.Never);
    }

    [Fact]
    public async Task InvokeAsync_ExtractsClientIdFromHeader()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        mockRouter.Setup(r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()))
            .ReturnsAsync(new GatewayResponse { Success = true });
        var middleware = CreateMiddleware(mockRouter);
        var context = CreateInvokeContext(JsonSerializer.Serialize(new { toolName = "test_tool" }));
        context.Request.Headers["X-Client-Id"] = "test-client";

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        mockRouter.Verify(
            r => r.RouteAsync(
                It.Is<GatewayRequest>(req => req.ClientId == "test-client"),
                It.IsAny<CancellationToken>()),
            Times.Once);
    }

    [Fact]
    public async Task InvokeAsync_ReturnsJsonResponse()
    {
        // Arrange
        var mockRouter = new Mock<IGatewayRouter>();
        var expectedResult = new Dictionary<string, object> { ["status"] = "success", ["value"] = 42 };
        mockRouter.Setup(r => r.RouteAsync(It.IsAny<GatewayRequest>(), It.IsAny<CancellationToken>()))
            .ReturnsAsync(new GatewayResponse { Success = true, Result = expectedResult });
        var middleware = CreateMiddleware(mockRouter);
        var context = CreateInvokeContext(JsonSerializer.Serialize(new { toolName = "test_tool" }));

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        Assert.Equal("application/json", context.Response.ContentType);
        context.Response.Body.Seek(0, SeekOrigin.Begin);
        using var reader = new StreamReader(context.Response.Body);
        var responseBody = await reader.ReadToEndAsync();

        // Parse rather than substring-match: if the middleware serializes the response envelope,
        // its Success flag alone would satisfy Contains("success") even if the result were dropped.
        using var document = JsonDocument.Parse(responseBody);
        Assert.True(
            ContainsProperty(document.RootElement, "status",
                v => v.ValueKind == JsonValueKind.String && v.GetString() == "success"),
            $"Expected a \"status\": \"success\" property in response: {responseBody}");
        Assert.True(
            ContainsProperty(document.RootElement, "value",
                v => v.ValueKind == JsonValueKind.Number && v.TryGetInt32(out var number) && number == 42),
            $"Expected a \"value\": 42 property in response: {responseBody}");
    }

    /// <summary>
    /// Builds the middleware under test around the given router mock.
    /// </summary>
    /// <param name="router">Router mock whose object is injected into the middleware.</param>
    /// <param name="next">Next delegate in the pipeline; defaults to a no-op.</param>
    private static GatewayMiddleware CreateMiddleware(Mock<IGatewayRouter> router, RequestDelegate? next = null) =>
        new(next: next ?? NoOpNext, router: router.Object);

    /// <summary>
    /// Creates a POST request to the gateway invoke endpoint with a JSON content type,
    /// the given UTF-8 body, and a readable in-memory response stream.
    /// </summary>
    /// <param name="body">Raw request body text (not required to be valid JSON).</param>
    private static DefaultHttpContext CreateInvokeContext(string body)
    {
        var context = new DefaultHttpContext();
        context.Request.Method = "POST";
        context.Request.Path = InvokePath;
        context.Request.ContentType = "application/json";
        context.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(body));
        // Replace the default null response stream so tests can read what was written.
        context.Response.Body = new MemoryStream();
        return context;
    }

    /// <summary>
    /// Recursively searches a JSON element for a property named <paramref name="name"/>
    /// (ordinal match) whose value satisfies <paramref name="predicate"/>, at any nesting depth.
    /// </summary>
    /// <returns><c>true</c> if a matching property is found; otherwise <c>false</c>.</returns>
    private static bool ContainsProperty(JsonElement element, string name, Func<JsonElement, bool> predicate)
    {
        switch (element.ValueKind)
        {
            case JsonValueKind.Object:
                foreach (var property in element.EnumerateObject())
                {
                    if (property.NameEquals(name) && predicate(property.Value))
                    {
                        return true;
                    }

                    if (ContainsProperty(property.Value, name, predicate))
                    {
                        return true;
                    }
                }

                return false;

            case JsonValueKind.Array:
                foreach (var item in element.EnumerateArray())
                {
                    if (ContainsProperty(item, name, predicate))
                    {
                        return true;
                    }
                }

                return false;

            default:
                return false;
        }
    }
}