diff --git a/src/zutil/net/dns/MulticastDnsServer.java b/src/zutil/net/dns/MulticastDnsServer.java index 2a203c4..c9a02bf 100755 --- a/src/zutil/net/dns/MulticastDnsServer.java +++ b/src/zutil/net/dns/MulticastDnsServer.java @@ -32,10 +32,13 @@ import zutil.net.dns.packet.DnsPacketResource; import zutil.net.threaded.ThreadedUDPNetwork; import zutil.net.threaded.ThreadedUDPNetworkThread; import zutil.parser.binary.BinaryStructInputStream; +import zutil.parser.binary.BinaryStructOutputStream; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.DatagramPacket; +import java.net.InetAddress; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; @@ -64,10 +67,18 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD public MulticastDnsServer() throws IOException { super(MDNS_MULTICAST_ADDR, MDNS_MULTICAST_PORT); setThread( this ); - } + /** + * Add a domain name specific data that will be returned to a requesting client + * + * @param name is the domain name to add the entry under + * @param ip the IPv4 address to respond with + */ + public void addEntry(String name, InetAddress ip){ + addEntry(name, DnsConstants.TYPE.A, DnsConstants.CLASS.IN, ip.getAddress()); + } /** * Add a domain name specific data that will be returned to a requesting client * @@ -90,7 +101,7 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD private void addEntry(DnsPacketResource resource) { if ( ! entries.containsKey(resource.name)) - entries.put(resource.name, new ArrayList<>()); + entries.put(resource.name, new ArrayList()); entries.get(resource.name).add(resource); } @@ -106,20 +117,18 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD // Just handle queries and no responses if ( ! dnsPacket.getHeader().flagQueryResponse){ - for (DnsPacketQuestion question : dnsPacket.getQuestions()){ - if (question.name == null) continue; - - switch(question.type){ - case DnsConstants.TYPE.PTR: - if (question.name.startsWith("_service")){ - - } - else if (entries.containsKey(question.name)){ - // Respond with entries - - } - break; - } + DnsPacket response = handleReceivedPacket(dnsPacket); + if (response != null){ + ByteArrayOutputStream outBuffer = new ByteArrayOutputStream(); + BinaryStructOutputStream out = new BinaryStructOutputStream(outBuffer); + response.write(out); + out.close(); + + DatagramPacket outPacket = new DatagramPacket( + outBuffer.toByteArray(), outBuffer.size(), + InetAddress.getByName( MDNS_MULTICAST_ADDR ), + MDNS_MULTICAST_PORT ); + send(outPacket); } } } catch (IOException e){ @@ -127,4 +136,38 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD } } + protected DnsPacket handleReceivedPacket(DnsPacket request){ + DnsPacket response = new DnsPacket(); + response.getHeader().setDefaultResponseData(); + for (DnsPacketQuestion question : request.getQuestions()){ + if (question.name == null) continue; + switch (question.type){ + + // Normal Domain Name Resolution + case DnsConstants.TYPE.A: + if (entries.containsKey(question.name)){ + response.addAnswerRecord(entries.get(question.name)); + } + break; + + // Service Name Resolution + case DnsConstants.TYPE.PTR: + if (question.name.startsWith("_service.")){ + String postFix = question.name.substring(9); + for (String domain : entries.keySet()){ + if (domain.endsWith(postFix)) + response.addAnswerRecord(entries.get(domain)); + } + } else if (entries.containsKey(question.name)){ + response.addAnswerRecord(entries.get(question.name)); + } + break; + } + } + if (response.getAnswerRecords().isEmpty() && + response.getNameServers().isEmpty() && + response.getAdditionalRecords().isEmpty()) + return null; + return response; + } } diff --git a/src/zutil/net/dns/packet/DnsPacket.java b/src/zutil/net/dns/packet/DnsPacket.java index 793bb74..b08669a 100755 --- a/src/zutil/net/dns/packet/DnsPacket.java +++ b/src/zutil/net/dns/packet/DnsPacket.java @@ -56,6 +56,10 @@ public class DnsPacket { answerRecords.add(resource); header.countAnswerRecord = answerRecords.size(); } + public void addAnswerRecord(List resources){ + answerRecords.addAll(resources); + header.countAnswerRecord = answerRecords.size(); + } public void addNameServer(DnsPacketResource resource){ nameServers.add(resource); header.countNameServer = nameServers.size(); diff --git a/src/zutil/parser/binary/BinaryStructInputStream.java b/src/zutil/parser/binary/BinaryStructInputStream.java index 997b035..2246cdc 100755 --- a/src/zutil/parser/binary/BinaryStructInputStream.java +++ b/src/zutil/parser/binary/BinaryStructInputStream.java @@ -52,9 +52,15 @@ public class BinaryStructInputStream { * Parses a byte array and assigns all fields in the struct */ public static int read(BinaryStruct struct, byte[] data) { + return read(struct, data, 0, data.length); + } + /** + * Parses a byte array and assigns all fields in the struct + */ + public static int read(BinaryStruct struct, byte[] data, int offset, int length) { int read = 0; try { - ByteArrayInputStream buffer = new ByteArrayInputStream(data); + ByteArrayInputStream buffer = new ByteArrayInputStream(data, offset, length); BinaryStructInputStream in = new BinaryStructInputStream(buffer); read = in.read(struct); } catch (Exception e){ diff --git a/test/zutil/net/dns/MulticastDnsServerTest.java b/test/zutil/net/dns/MulticastDnsServerTest.java new file mode 100755 index 0000000..f577e16 --- /dev/null +++ b/test/zutil/net/dns/MulticastDnsServerTest.java @@ -0,0 +1,80 @@ +package zutil.net.dns; + +import org.junit.Test; +import zutil.net.dns.packet.DnsConstants; +import zutil.net.dns.packet.DnsPacket; +import zutil.net.dns.packet.DnsPacketQuestion; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; + +import static org.junit.Assert.*; + +/** + * + */ +public class MulticastDnsServerTest { + + private MulticastDnsServer server = new MulticastDnsServer(); + public MulticastDnsServerTest() throws IOException {} + + + @Test + public void domainLookupNoEntries(){ + DnsPacket request = creatRequestDnsPacket( + "example.com", + DnsConstants.TYPE.A); + + DnsPacket response = server.handleReceivedPacket(request); + assertNull(response); + } + + + @Test + public void domainLookup() throws UnknownHostException { + DnsPacket request = creatRequestDnsPacket( + "example.com", + DnsConstants.TYPE.A); + + server.addEntry("example.com", InetAddress.getLocalHost()); + DnsPacket response = server.handleReceivedPacket(request); + assertNotNull(response); + assertEquals("example.com", response.getAnswerRecords().get(0).name); + } + + + @Test + public void serviceDiscoveryNoEntries(){ + DnsPacket request = creatRequestDnsPacket( + "_service._tcp.local", + DnsConstants.TYPE.PTR); + + DnsPacket response = server.handleReceivedPacket(request); + assertNull(response); + } + + + @Test + public void serviceDiscovery() throws UnknownHostException { + DnsPacket request = creatRequestDnsPacket( + "_service._tcp.local", + DnsConstants.TYPE.PTR); + + server.addEntry("_http._tcp.local", InetAddress.getLocalHost()); + DnsPacket response = server.handleReceivedPacket(request); + assertNotNull(response); + } + + + private static DnsPacket creatRequestDnsPacket(String domain, int type){ + DnsPacket request = new DnsPacket(); + request.getHeader().setDefaultQueryData(); + DnsPacketQuestion question = new DnsPacketQuestion(); + question.name = domain; + question.type = type; + question.clazz = DnsConstants.CLASS.IN; + request.addQuestion(question); + return request; + } +} \ No newline at end of file