diff --git a/src/SharpLinks.cs b/src/SharpLinks.cs index baa1321..fe2eeb0 100644 --- a/src/SharpLinks.cs +++ b/src/SharpLinks.cs @@ -60,8 +60,16 @@ public IContext Initialize(IContext context, LdapConfig options) { if (string.IsNullOrWhiteSpace(context.DomainName)) { if (!context.LDAPUtils.GetDomain(out var d)) { - context.Logger.LogCritical("unable to get current domain"); - context.Flags.IsFaulted = true; + if (TryInferDomainName(options, out var inferredDomain)) { + context.DomainName = inferredDomain; + context.Logger.LogInformation( + "Unable to resolve current domain from host context, inferred domain {Domain} from supplied connection options", + inferredDomain); + } else { + context.Logger.LogCritical( + "unable to get current domain; specify --domain or provide ldap credentials/domain controller that include domain information"); + context.Flags.IsFaulted = true; + } } else { context.DomainName = d.Name; context.Logger.LogInformation("Resolved current domain to {Domain}", d.Name); @@ -203,12 +211,7 @@ public async Task GetDomainsForEnumeration(IContext context) { return context; } - if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) { - context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); - context.Flags.IsFaulted = true; - return context; - } - + var hasDomainObject = context.LDAPUtils.GetDomain(context.DomainName, out var domainObject); var domain = domainObject?.Name ?? context.DomainName; if (domain == null) { context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); @@ -216,7 +219,7 @@ public async Task GetDomainsForEnumeration(IContext context) { return context; } - if (domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject() + if (hasDomainObject && domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject() .TryGetSecurityIdentifier(out var sid)) { context.Domains = new[] { new EnumerationDomain { @@ -224,7 +227,18 @@ public async Task GetDomainsForEnumeration(IContext context) { DomainSid = sid } }; + } else if (await context.LDAPUtils.GetDomainSidFromDomainName(domain) is (true, var resolvedSid)) { + context.Logger.LogDebug("Resolved domain SID for {Domain} through LDAP", domain); + context.Domains = new[] { + new EnumerationDomain { + Name = domain, + DomainSid = resolvedSid + } + }; } else { + context.Logger.LogWarning( + "Could not resolve domain object for {Domain}; continuing with provided domain and unknown SID", + domain); context.Domains = new[] { new EnumerationDomain { Name = domain, @@ -237,6 +251,39 @@ public async Task GetDomainsForEnumeration(IContext context) { return context; } + private static bool TryInferDomainName(LdapConfig options, out string domainName) { + domainName = null; + if (options == null) { + return false; + } + + if (!string.IsNullOrWhiteSpace(options.Username)) { + var username = options.Username.Trim(); + var atIndex = username.IndexOf('@'); + if (atIndex > 0 && atIndex < username.Length - 1) { + domainName = username.Substring(atIndex + 1); + return true; + } + + var slashIndex = username.IndexOf('\\'); + if (slashIndex > 0) { + domainName = username.Substring(0, slashIndex); + return true; + } + } + + if (!string.IsNullOrWhiteSpace(options.Server)) { + var server = options.Server.Trim(); + var firstDot = server.IndexOf('.'); + if (firstDot > 0 && firstDot < server.Length - 1) { + domainName = server.Substring(firstDot + 1); + return true; + } + } + + return false; + } + private async IAsyncEnumerable BuildRecursiveDomainList(IContext context) { var domainResults = new List(); var enumeratedDomains = new HashSet();