Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions src/SharpLinks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,16 @@ public IContext Initialize(IContext context, LdapConfig options) {

if (string.IsNullOrWhiteSpace(context.DomainName)) {
if (!context.LDAPUtils.GetDomain(out var d)) {
context.Logger.LogCritical("unable to get current domain");
context.Flags.IsFaulted = true;
if (TryInferDomainName(options, out var inferredDomain)) {
context.DomainName = inferredDomain;
context.Logger.LogInformation(
"Unable to resolve current domain from host context, inferred domain {Domain} from supplied connection options",
inferredDomain);
} else {
context.Logger.LogCritical(
"unable to get current domain; specify --domain or provide ldap credentials/domain controller that include domain information");
context.Flags.IsFaulted = true;
}
} else {
context.DomainName = d.Name;
context.Logger.LogInformation("Resolved current domain to {Domain}", d.Name);
Expand Down Expand Up @@ -203,28 +211,34 @@ public async Task<IContext> GetDomainsForEnumeration(IContext context) {
return context;
}

if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) {
context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling");
context.Flags.IsFaulted = true;
return context;
}

var hasDomainObject = context.LDAPUtils.GetDomain(context.DomainName, out var domainObject);
var domain = domainObject?.Name ?? context.DomainName;
if (domain == null) {
context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling");
context.Flags.IsFaulted = true;
return context;
}

if (domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject()
if (hasDomainObject && domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject()
.TryGetSecurityIdentifier(out var sid)) {
context.Domains = new[] {
new EnumerationDomain {
Name = domain,
DomainSid = sid
}
};
} else if (await context.LDAPUtils.GetDomainSidFromDomainName(domain) is (true, var resolvedSid)) {
context.Logger.LogDebug("Resolved domain SID for {Domain} through LDAP", domain);
context.Domains = new[] {
new EnumerationDomain {
Name = domain,
DomainSid = resolvedSid
}
};
} else {
context.Logger.LogWarning(
"Could not resolve domain object for {Domain}; continuing with provided domain and unknown SID",
domain);
context.Domains = new[] {
new EnumerationDomain {
Name = domain,
Expand All @@ -237,6 +251,39 @@ public async Task<IContext> GetDomainsForEnumeration(IContext context) {
return context;
}

private static bool TryInferDomainName(LdapConfig options, out string domainName) {
domainName = null;
if (options == null) {
return false;
}

if (!string.IsNullOrWhiteSpace(options.Username)) {
var username = options.Username.Trim();
var atIndex = username.IndexOf('@');
if (atIndex > 0 && atIndex < username.Length - 1) {
domainName = username.Substring(atIndex + 1);
return true;
}

var slashIndex = username.IndexOf('\\');
if (slashIndex > 0) {
domainName = username.Substring(0, slashIndex);
return true;
}
}

if (!string.IsNullOrWhiteSpace(options.Server)) {
var server = options.Server.Trim();
var firstDot = server.IndexOf('.');
if (firstDot > 0 && firstDot < server.Length - 1) {
domainName = server.Substring(firstDot + 1);
return true;
}
Comment on lines +275 to +281
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Parse options.Server before domain extraction to avoid invalid inferred domains.

Line 277 currently parses the raw server string. If Server contains a port or IP, inference can produce invalid domain values and break downstream SID lookup.

Proposed fix
             if (!string.IsNullOrWhiteSpace(options.Server)) {
-                var server = options.Server.Trim();
-                var firstDot = server.IndexOf('.');
-                if (firstDot > 0 && firstDot < server.Length - 1) {
-                    domainName = server.Substring(firstDot + 1);
+                var server = options.Server.Trim();
+                var host = server;
+
+                // Handle absolute forms like ldap://dc.test.local:389
+                if (server.Contains("://") && Uri.TryCreate(server, UriKind.Absolute, out var uri)) {
+                    host = uri.Host;
+                } else {
+                    // Handle host:port without scheme
+                    var colon = server.LastIndexOf(':');
+                    if (colon > 0 && server.IndexOf(':') == colon) {
+                        host = server.Substring(0, colon);
+                    }
+                }
+
+                // Skip IP literals; they cannot be used to infer a DNS domain safely
+                if (System.Net.IPAddress.TryParse(host, out _)) {
+                    return false;
+                }
+
+                var firstDot = host.IndexOf('.');
+                if (firstDot > 0 && firstDot < host.Length - 1) {
+                    domainName = host.Substring(firstDot + 1);
                     return true;
                 }
             }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (!string.IsNullOrWhiteSpace(options.Server)) {
var server = options.Server.Trim();
var firstDot = server.IndexOf('.');
if (firstDot > 0 && firstDot < server.Length - 1) {
domainName = server.Substring(firstDot + 1);
return true;
}
if (!string.IsNullOrWhiteSpace(options.Server)) {
var server = options.Server.Trim();
var host = server;
// Handle absolute forms like ldap://dc.test.local:389
if (server.Contains("://") && Uri.TryCreate(server, UriKind.Absolute, out var uri)) {
host = uri.Host;
} else {
// Handle host:port without scheme
var colon = server.LastIndexOf(':');
if (colon > 0 && server.IndexOf(':') == colon) {
host = server.Substring(0, colon);
}
}
// Skip IP literals; they cannot be used to infer a DNS domain safely
if (System.Net.IPAddress.TryParse(host, out _)) {
return false;
}
var firstDot = host.IndexOf('.');
if (firstDot > 0 && firstDot < host.Length - 1) {
domainName = host.Substring(firstDot + 1);
return true;
}
}

}

return false;
}

private async IAsyncEnumerable<EnumerationDomain> BuildRecursiveDomainList(IContext context) {
var domainResults = new List<EnumerationDomain>();
var enumeratedDomains = new HashSet<string>();
Expand Down
Loading