Refactore MDNS so it can be tested, and also added Junit test for it

This commit is contained in:
Ziver Koc 2017-02-07 18:41:35 +01:00
parent 6a3744eb99
commit 84a3b13b76
4 changed files with 150 additions and 17 deletions

View file

@ -32,10 +32,13 @@ import zutil.net.dns.packet.DnsPacketResource;
import zutil.net.threaded.ThreadedUDPNetwork;
import zutil.net.threaded.ThreadedUDPNetworkThread;
import zutil.parser.binary.BinaryStructInputStream;
import zutil.parser.binary.BinaryStructOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
@ -64,10 +67,18 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD
public MulticastDnsServer() throws IOException {
super(MDNS_MULTICAST_ADDR, MDNS_MULTICAST_PORT);
setThread( this );
}
/**
* Add a domain name specific data that will be returned to a requesting client
*
* @param name is the domain name to add the entry under
* @param ip the IPv4 address to respond with
*/
public void addEntry(String name, InetAddress ip){
addEntry(name, DnsConstants.TYPE.A, DnsConstants.CLASS.IN, ip.getAddress());
}
/**
* Add a domain name specific data that will be returned to a requesting client
*
@ -90,7 +101,7 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD
private void addEntry(DnsPacketResource resource) {
if ( ! entries.containsKey(resource.name))
entries.put(resource.name, new ArrayList<>());
entries.put(resource.name, new ArrayList<DnsPacketResource>());
entries.get(resource.name).add(resource);
}
@ -106,20 +117,18 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD
// Just handle queries and no responses
if ( ! dnsPacket.getHeader().flagQueryResponse){
for (DnsPacketQuestion question : dnsPacket.getQuestions()){
if (question.name == null) continue;
switch(question.type){
case DnsConstants.TYPE.PTR:
if (question.name.startsWith("_service")){
}
else if (entries.containsKey(question.name)){
// Respond with entries
}
break;
}
DnsPacket response = handleReceivedPacket(dnsPacket);
if (response != null){
ByteArrayOutputStream outBuffer = new ByteArrayOutputStream();
BinaryStructOutputStream out = new BinaryStructOutputStream(outBuffer);
response.write(out);
out.close();
DatagramPacket outPacket = new DatagramPacket(
outBuffer.toByteArray(), outBuffer.size(),
InetAddress.getByName( MDNS_MULTICAST_ADDR ),
MDNS_MULTICAST_PORT );
send(outPacket);
}
}
} catch (IOException e){
@ -127,4 +136,38 @@ public class MulticastDnsServer extends ThreadedUDPNetwork implements ThreadedUD
}
}
protected DnsPacket handleReceivedPacket(DnsPacket request){
DnsPacket response = new DnsPacket();
response.getHeader().setDefaultResponseData();
for (DnsPacketQuestion question : request.getQuestions()){
if (question.name == null) continue;
switch (question.type){
// Normal Domain Name Resolution
case DnsConstants.TYPE.A:
if (entries.containsKey(question.name)){
response.addAnswerRecord(entries.get(question.name));
}
break;
// Service Name Resolution
case DnsConstants.TYPE.PTR:
if (question.name.startsWith("_service.")){
String postFix = question.name.substring(9);
for (String domain : entries.keySet()){
if (domain.endsWith(postFix))
response.addAnswerRecord(entries.get(domain));
}
} else if (entries.containsKey(question.name)){
response.addAnswerRecord(entries.get(question.name));
}
break;
}
}
if (response.getAnswerRecords().isEmpty() &&
response.getNameServers().isEmpty() &&
response.getAdditionalRecords().isEmpty())
return null;
return response;
}
}

View file

@ -56,6 +56,10 @@ public class DnsPacket {
answerRecords.add(resource);
header.countAnswerRecord = answerRecords.size();
}
public void addAnswerRecord(List<DnsPacketResource> resources){
answerRecords.addAll(resources);
header.countAnswerRecord = answerRecords.size();
}
public void addNameServer(DnsPacketResource resource){
nameServers.add(resource);
header.countNameServer = nameServers.size();

View file

@ -52,9 +52,15 @@ public class BinaryStructInputStream {
* Parses a byte array and assigns all fields in the struct
*/
public static int read(BinaryStruct struct, byte[] data) {
return read(struct, data, 0, data.length);
}
/**
* Parses a byte array and assigns all fields in the struct
*/
public static int read(BinaryStruct struct, byte[] data, int offset, int length) {
int read = 0;
try {
ByteArrayInputStream buffer = new ByteArrayInputStream(data);
ByteArrayInputStream buffer = new ByteArrayInputStream(data, offset, length);
BinaryStructInputStream in = new BinaryStructInputStream(buffer);
read = in.read(struct);
} catch (Exception e){

View file

@ -0,0 +1,80 @@
package zutil.net.dns;
import org.junit.Test;
import zutil.net.dns.packet.DnsConstants;
import zutil.net.dns.packet.DnsPacket;
import zutil.net.dns.packet.DnsPacketQuestion;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import static org.junit.Assert.*;
/**
*
*/
public class MulticastDnsServerTest {
private MulticastDnsServer server = new MulticastDnsServer();
public MulticastDnsServerTest() throws IOException {}
@Test
public void domainLookupNoEntries(){
DnsPacket request = creatRequestDnsPacket(
"example.com",
DnsConstants.TYPE.A);
DnsPacket response = server.handleReceivedPacket(request);
assertNull(response);
}
@Test
public void domainLookup() throws UnknownHostException {
DnsPacket request = creatRequestDnsPacket(
"example.com",
DnsConstants.TYPE.A);
server.addEntry("example.com", InetAddress.getLocalHost());
DnsPacket response = server.handleReceivedPacket(request);
assertNotNull(response);
assertEquals("example.com", response.getAnswerRecords().get(0).name);
}
@Test
public void serviceDiscoveryNoEntries(){
DnsPacket request = creatRequestDnsPacket(
"_service._tcp.local",
DnsConstants.TYPE.PTR);
DnsPacket response = server.handleReceivedPacket(request);
assertNull(response);
}
@Test
public void serviceDiscovery() throws UnknownHostException {
DnsPacket request = creatRequestDnsPacket(
"_service._tcp.local",
DnsConstants.TYPE.PTR);
server.addEntry("_http._tcp.local", InetAddress.getLocalHost());
DnsPacket response = server.handleReceivedPacket(request);
assertNotNull(response);
}
private static DnsPacket creatRequestDnsPacket(String domain, int type){
DnsPacket request = new DnsPacket();
request.getHeader().setDefaultQueryData();
DnsPacketQuestion question = new DnsPacketQuestion();
question.name = domain;
question.type = type;
question.clazz = DnsConstants.CLASS.IN;
request.addQuestion(question);
return request;
}
}