Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ public static class DbContextExtensions
{
public static void SetTenantIdFromContext(this DbContext context)
{
var tenantId = context.GetTenantIdFromContext();

var multiTenantEntities =
context.ChangeTracker.Entries()
.Where(e => e.IsMultiTenant() && e.State != EntityState.Unchanged);

if (!multiTenantEntities.Any())
{
return;
}

var tenantId = context.GetTenantIdFromContext();
foreach (var e in multiTenantEntities)
{
var attemptedTenantId = e.GetTenantId();
Expand All @@ -41,4 +45,4 @@ public static void UseMultitenancy(this DbContextOptionsBuilder options, IServic
((IDbContextOptionsBuilderInfrastructure)options).AddOrUpdateExtension(extension);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace NBB.MultiTenancy.Abstractions.Context
{
public static class TenantContextExtensions
{
public static Guid GetTenantId(this TenantContext tenantContext) => tenantContext.Tenant?.TenantId ?? throw new TenantNotFoundException();
public static Guid? TryGetTenantId(this TenantContext tenantContext) => tenantContext.Tenant?.TenantId;
public static string GetTenantCode(this TenantContext tenantContext) => tenantContext.Tenant?.Code ?? throw new TenantNotFoundException();
/// <summary>
/// Returns the identifier of the tenant carried by <paramref name="tenantContext"/>.
/// </summary>
/// <param name="tenantContext">The tenant context to read; a <c>null</c> context is tolerated.</param>
/// <returns>The tenant identifier.</returns>
/// <exception cref="TenantNotFoundException">
/// Thrown when the context, its tenant, or the tenant identifier is <c>null</c>.
/// </exception>
public static Guid GetTenantId(this TenantContext tenantContext)
{
    Guid? resolvedTenantId = tenantContext?.Tenant?.TenantId;
    return resolvedTenantId ?? throw new TenantNotFoundException();
}
/// <summary>
/// Returns the identifier of the tenant carried by <paramref name="tenantContext"/>, without throwing.
/// </summary>
/// <param name="tenantContext">The tenant context to read; a <c>null</c> context is tolerated.</param>
/// <returns>
/// The tenant identifier, or <c>null</c> when the context or its tenant is <c>null</c>.
/// </returns>
public static Guid? TryGetTenantId(this TenantContext tenantContext)
{
    return tenantContext?.Tenant?.TenantId;
}
/// <summary>
/// Returns the code of the tenant carried by <paramref name="tenantContext"/>.
/// </summary>
/// <param name="tenantContext">The tenant context to read; a <c>null</c> context is tolerated.</param>
/// <returns>The tenant code.</returns>
/// <exception cref="TenantNotFoundException">
/// Thrown when the context, its tenant, or the tenant code is <c>null</c>.
/// </exception>
public static string GetTenantCode(this TenantContext tenantContext)
{
    var resolvedCode = tenantContext?.Tenant?.Code;
    return resolvedCode ?? throw new TenantNotFoundException();
}

/// <summary>
/// Convenience overload that wraps <paramref name="tenant"/> in a new <see cref="TenantContext"/>
/// and forwards it to the accessor's own <c>ChangeTenantContext</c>.
/// </summary>
/// <param name="tenantContextAccessor">The accessor whose tenant context is being changed.</param>
/// <param name="tenant">The tenant to wrap in the new context.</param>
/// <returns>The <see cref="TenantContextFlow"/> returned by the accessor.</returns>
public static TenantContextFlow ChangeTenantContext(this ITenantContextAccessor tenantContextAccessor, Tenant tenant)
{
    var wrappedContext = new TenantContext(tenant);
    return tenantContextAccessor.ChangeTenantContext(wrappedContext);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
Expand All @@ -15,6 +11,10 @@
using NBB.MultiTenancy.Abstractions.Configuration;
using NBB.MultiTenancy.Abstractions.Context;
using NBB.MultiTenancy.Abstractions.Options;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
Expand All @@ -26,7 +26,7 @@ public async Task Should_add_tenantId()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };

await WithTenantScope(sp, testTenantId, async sp =>
Expand All @@ -48,7 +48,7 @@ public async Task Should_Exception_Be_Thrown_If_Different_TenantIds()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };
var testEntityOtherId = new TestEntity { Id = 2 };
await WithTenantScope(sp, testTenantId, async sp =>
Expand All @@ -75,7 +75,7 @@ public async Task Shoud_Apply_Filter()
var testTenantId2 = Guid.NewGuid();
var testEntity = new TestEntity { Id = 1 };
var testEntityOtherId = new TestEntity { Id = 2 };
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();

await WithTenantScope(sp, testTenantId1, async sp =>
{
Expand Down Expand Up @@ -109,7 +109,7 @@ public async Task Should_add_TenantId_and_filter_for_MultiTenantContext()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };
var testEntity1 = new TestEntity { Id = 2 };

Expand All @@ -133,13 +133,44 @@ await WithTenantScope(sp, testTenantId, async sp =>
});
}

private IServiceProvider GetServiceProvider<TDBContext>(bool isSharedDB) where TDBContext : DbContext
/// <summary>
/// Saving a <see cref="TestDbContext"/> that only tracks added <see cref="SimpleEntity"/> rows
/// must succeed even though no tenant scope was opened around the save.
/// </summary>
[Fact]
public async Task Can_Save_MultiTenantDbContext_WO_TennatContext_When_Only_NonMultiTenant_Entities_Changed()
{
    // arrange
    var serviceProvider = GetServiceProvider<TestDbContext>(DbStrategy.Shared);
    var firstEntity = new SimpleEntity { Id = 1 };
    var secondEntity = new SimpleEntity { Id = 2 };

    // The context is resolved directly from the provider; no tenant scope wraps the save.
    var dbContext = serviceProvider.GetRequiredService<TestDbContext>();
    var simpleEntities = dbContext.SimpleEntities;
    simpleEntities.Add(firstEntity);
    simpleEntities.Add(secondEntity);

    // act
    var savedCount = await dbContext.SaveChangesAsync();

    // assert
    savedCount.Should().Be(2);
}

// Selects how GetServiceProvider wires the "myDb" connection string for the test DbContext.
enum DbStrategy
{
// Connection string is stored under "MultiTenancy:Defaults:ConnectionStrings:myDb" with a freshly
// generated Guid as its value, and is resolved through ITenantConfiguration.
DatabasePerTenant,
// Connection string is stored under "ConnectionStrings:myDb" with the value "Test", and is read
// directly from IConfiguration instead of ITenantConfiguration.
Shared,
// Connection string is stored under "MultiTenancy:Defaults:ConnectionStrings:myDb" with the value
// "Test", and is resolved through ITenantConfiguration. This is GetServiceProvider's default.
Hybrid
}

private IServiceProvider GetServiceProvider<TDBContext>(DbStrategy dbStrategy = DbStrategy.Hybrid) where TDBContext : DbContext
{
var tenantService = Mock.Of<ITenantContextAccessor>(x => x.TenantContext == null);
var isSharedDB = dbStrategy == DbStrategy.Shared;
var isHybridDB = dbStrategy == DbStrategy.Hybrid;
var connectionStringKey = isSharedDB ? "ConnectionStrings:myDb" : "MultiTenancy:Defaults:ConnectionStrings:myDb";
var connectionStringValue = isSharedDB || isHybridDB ? "Test" : Guid.NewGuid().ToString();
IConfiguration configuration = new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
{
{ "MultiTenancy:Defaults:ConnectionStrings:myDb", isSharedDB ? "Test" : Guid.NewGuid().ToString()}
{ connectionStringKey, connectionStringValue }
})
.Build();

Expand All @@ -156,7 +187,9 @@ private IServiceProvider GetServiceProvider<TDBContext>(bool isSharedDB) where T
services.AddEntityFrameworkInMemoryDatabase()
.AddDbContext<TDBContext>((sp, options) =>
{
var conn = sp.GetRequiredService<ITenantConfiguration>().GetConnectionString("myDb");
var conn = isSharedDB ?
configuration.GetConnectionString("myDb") :
sp.GetRequiredService<ITenantConfiguration>().GetConnectionString("myDb");
options.UseInMemoryDatabase(conn).UseInternalServiceProvider(sp);
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
{
/// <summary>
/// Minimal test entity that exposes only an integer identifier and declares no tenant-related members.
/// </summary>
public class SimpleEntity
{
/// <summary>Gets or sets the entity identifier.</summary>
public int Id { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
{
    /// <summary>
    /// Entity Framework mapping for <see cref="SimpleEntity"/>; it only declares the primary key.
    /// </summary>
    public class SimpleEntityConfiguration : IEntityTypeConfiguration<SimpleEntity>
    {
        /// <summary>
        /// Registers <see cref="SimpleEntity.Id"/> as the key of the entity.
        /// </summary>
        /// <param name="builder">The builder used to configure <see cref="SimpleEntity"/>.</param>
        public void Configure(EntityTypeBuilder<SimpleEntity> builder)
            => builder.HasKey(entity => entity.Id);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace NBB.Data.EntityFramework.MultiTenancy.Tests
public class TestDbContext : MultiTenantDbContext
{
public DbSet<TestEntity> TestEntities { get; set; }
public DbSet<SimpleEntity> SimpleEntities { get; set; }

public TestDbContext(DbContextOptions<TestDbContext> options) : base(options)
{
Expand All @@ -16,6 +17,7 @@ public TestDbContext(DbContextOptions<TestDbContext> options) : base(options)
/// <summary>
/// Applies the entity configurations for <see cref="TestEntity"/> and <see cref="SimpleEntity"/>,
/// then delegates to the base implementation.
/// </summary>
/// <param name="modelBuilder">The builder used to construct the model for this context.</param>
protected override void OnModelCreating(ModelBuilder modelBuilder)
{
    // ApplyConfiguration returns the same builder, so the registrations can be chained
    // while keeping their original order.
    modelBuilder
        .ApplyConfiguration(new TestEntityConfiguration())
        .ApplyConfiguration(new SimpleEntityConfiguration());

    base.OnModelCreating(modelBuilder);
}
Expand Down
Loading