Added subscribe and unsubscribe support and TCs

This commit is contained in:
Ziver Koc 2018-11-12 16:24:05 +01:00
parent f175830fae
commit 5ecfee9dfd
5 changed files with 201 additions and 65 deletions

View file

@ -1,5 +1,6 @@
package zutil.net.mqtt; package zutil.net.mqtt;
import zutil.ObjectUtil;
import zutil.log.LogUtil; import zutil.log.LogUtil;
import zutil.net.mqtt.packet.*; import zutil.net.mqtt.packet.*;
import zutil.net.mqtt.packet.MqttPacketSubscribe.MqttSubscribePayload; import zutil.net.mqtt.packet.MqttPacketSubscribe.MqttSubscribePayload;
@ -37,7 +38,20 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
@Override @Override
protected ThreadedTCPNetworkServerThread getThreadInstance(Socket s) throws IOException { protected ThreadedTCPNetworkServerThread getThreadInstance(Socket s) throws IOException {
return new MqttConnectionThread(s); return new MqttConnectionThread(this, s);
}
/**
* @return the subscriber count for the specific topic, -1 if
* topic does not exist or has not been created yet.
*/
public int getSubscriberCount(String topic) {
List topicSubscriptions = subscriptions.get(topic);
if (topicSubscriptions != null) {
return topicSubscriptions.size();
}
return -1;
} }
@ -51,7 +65,7 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
} }
List topicSubscriptions = subscriptions.get(topic); List topicSubscriptions = subscriptions.get(topic);
if (topicSubscriptions.contains(listener)) { if (!topicSubscriptions.contains(listener)) {
logger.finer("New subscriber on topic (" + topic + "), subscriber count: " + topicSubscriptions.size()); logger.finer("New subscriber on topic (" + topic + "), subscriber count: " + topicSubscriptions.size());
topicSubscriptions.add(listener); topicSubscriptions.add(listener);
} }
@ -85,16 +99,22 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
protected static class MqttConnectionThread implements ThreadedTCPNetworkServerThread, MqttSubscriptionListener { protected static class MqttConnectionThread implements ThreadedTCPNetworkServerThread, MqttSubscriptionListener {
private MqttBroker broker;
private Socket socket; private Socket socket;
private BinaryStructInputStream in; private BinaryStructInputStream in;
private BinaryStructOutputStream out; private BinaryStructOutputStream out;
private boolean shutdown = false; private boolean shutdown = false;
/**
* Test constructor
*/
protected MqttConnectionThread(MqttBroker b) {
broker = b;
}
protected MqttConnectionThread() {} // Test constructor public MqttConnectionThread(MqttBroker b, Socket s) throws IOException {
this(b);
public MqttConnectionThread(Socket s) throws IOException {
socket = s; socket = s;
in = new BinaryStructInputStream(socket.getInputStream()); in = new BinaryStructInputStream(socket.getInputStream());
out = new BinaryStructOutputStream(socket.getOutputStream()); out = new BinaryStructOutputStream(socket.getOutputStream());
@ -106,26 +126,7 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
try { try {
// Setup connection // Setup connection
MqttPacketHeader connectPacket = MqttPacket.read(in); MqttPacketHeader connectPacket = MqttPacket.read(in);
// Unexpected packet? handleConnect(connectPacket);
if (!(connectPacket instanceof MqttPacketConnect))
throw new IOException("Expected MqttPacketConnect but received " + connectPacket.getClass());
MqttPacketConnect conn = (MqttPacketConnect) connectPacket;
// Reply
MqttPacketConnectAck connectAck = new MqttPacketConnectAck();
// Incorrect protocol version?
if (conn.protocolLevel != MQTT_PROTOCOL_VERSION) {
connectAck.returnCode = MqttPacketConnectAck.RETCODE_PROT_VER_ERROR;
sendPacket(connectAck);
return;
} else {
connectAck.returnCode = MqttPacketConnectAck.RETCODE_OK;
}
// TODO: authenticate
// TODO: clean session
sendPacket(connectAck);
// Connected // Connected
@ -137,53 +138,32 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
handlePacket(packet); handlePacket(packet);
} }
socket.close();
} catch (IOException e) { } catch (IOException e) {
logger.log(Level.SEVERE, null, e); logger.log(Level.SEVERE, null, e);
} finally { } finally {
try { try {
socket.close(); socket.close();
broker.unsubscribe(this);
} catch (IOException e) { } catch (IOException e) {
logger.log(Level.SEVERE, null, e); logger.log(Level.SEVERE, null, e);
} }
} }
} }
public void handlePacket(MqttPacketHeader packet) throws IOException { protected void handlePacket(MqttPacketHeader packet) throws IOException {
// TODO: QOS // TODO: QOS
switch (packet.type) { switch (packet.type) {
// TODO: Publish
case MqttPacketHeader.PACKET_TYPE_PUBLISH: case MqttPacketHeader.PACKET_TYPE_PUBLISH:
handlePublish((MqttPacketPublish) packet);
break; break;
// TODO: Subscribe
case MqttPacketHeader.PACKET_TYPE_SUBSCRIBE: case MqttPacketHeader.PACKET_TYPE_SUBSCRIBE:
MqttPacketSubscribe subscribePacket = (MqttPacketSubscribe) packet; handleSubscribe((MqttPacketSubscribe) packet);
MqttPacketSubscribeAck subscribeAckPacket = new MqttPacketSubscribeAck();
subscribeAckPacket.packetId = subscribePacket.packetId;
for (MqttSubscribePayload payload : subscribePacket.payload) {
// TODO: subscribe(payload.topicFilter, this)
MqttSubscribeAckPayload ackPayload = new MqttSubscribeAckPayload();
ackPayload.returnCode = MqttSubscribeAckPayload.RETCODE_SUCESS_MAX_QOS_0;
subscribeAckPacket.payload.add(ackPayload);
}
sendPacket(subscribeAckPacket);
break; break;
// TODO: Unsubscribe
case MqttPacketHeader.PACKET_TYPE_UNSUBSCRIBE: case MqttPacketHeader.PACKET_TYPE_UNSUBSCRIBE:
MqttPacketUnsubscribe unsubscribePacket = (MqttPacketUnsubscribe) packet; handleUnsubscribe((MqttPacketUnsubscribe) packet);
for (MqttUnsubscribePayload payload : unsubscribePacket.payload) {
// TODO: unsubscribe(payload.topicFilter, this)
}
MqttPacketUnsubscribeAck unsubscribeAckPacket = new MqttPacketUnsubscribeAck();
unsubscribeAckPacket.packetId = unsubscribePacket.packetId;
sendPacket(unsubscribeAckPacket);
break; break;
// Ping // Ping
@ -200,6 +180,61 @@ public class MqttBroker extends ThreadedTCPNetworkServer {
} }
} }
private void handleConnect(MqttPacketHeader connectPacket) throws IOException {
// Unexpected packet?
if (!(connectPacket instanceof MqttPacketConnect))
throw new IOException("Expected MqttPacketConnect but received " + connectPacket.getClass());
MqttPacketConnect conn = (MqttPacketConnect) connectPacket;
// Reply
MqttPacketConnectAck connectAck = new MqttPacketConnectAck();
// Incorrect protocol version?
if (conn.protocolLevel != MQTT_PROTOCOL_VERSION) {
connectAck.returnCode = MqttPacketConnectAck.RETCODE_PROT_VER_ERROR;
sendPacket(connectAck);
return;
} else {
connectAck.returnCode = MqttPacketConnectAck.RETCODE_OK;
}
// TODO: authenticate
// TODO: clean session
sendPacket(connectAck);
}
private void handlePublish(MqttPacketPublish publishPacket) throws IOException {
// TODO: Publish
}
private void handleSubscribe(MqttPacketSubscribe subscribePacket) throws IOException {
MqttPacketSubscribeAck subscribeAckPacket = new MqttPacketSubscribeAck();
subscribeAckPacket.packetId = subscribePacket.packetId;
for (MqttSubscribePayload payload : subscribePacket.payload) {
broker.subscribe(payload.topicFilter, this);
// Prepare response
MqttSubscribeAckPayload ackPayload = new MqttSubscribeAckPayload();
ackPayload.returnCode = MqttSubscribeAckPayload.RETCODE_SUCESS_MAX_QOS_0;
subscribeAckPacket.payload.add(ackPayload);
}
sendPacket(subscribeAckPacket);
}
private void handleUnsubscribe(MqttPacketUnsubscribe unsubscribePacket) throws IOException {
for (MqttUnsubscribePayload payload : unsubscribePacket.payload) {
broker.unsubscribe(payload.topicFilter, this);
}
// Prepare response
MqttPacketUnsubscribeAck unsubscribeAckPacket = new MqttPacketUnsubscribeAck();
unsubscribeAckPacket.packetId = unsubscribePacket.packetId;
sendPacket(unsubscribeAckPacket);
}
@Override @Override
public void dataPublished(String topic, String data) { public void dataPublished(String topic, String data) {

View file

@ -30,5 +30,5 @@ package zutil.net.mqtt;
*/ */
public interface MqttSubscriptionListener { public interface MqttSubscriptionListener {
public void dataPublished(String topic, String data); void dataPublished(String topic, String data);
} }

View file

@ -9,7 +9,7 @@ import java.io.IOException;
import static zutil.net.mqtt.packet.MqttPacketHeader.*; import static zutil.net.mqtt.packet.MqttPacketHeader.*;
/** /**
* A data class encapsulating a MQTT header and its controlHeader * A class for serializing and deserialize MQTT data packets
*/ */
public class MqttPacket { public class MqttPacket {

View file

@ -14,7 +14,6 @@ public class MqttPacketPublish extends MqttPacketHeader {
type = MqttPacketHeader.PACKET_TYPE_PUBLISH; type = MqttPacketHeader.PACKET_TYPE_PUBLISH;
} }
// Static Header
/* /*
@BinaryField(index = 2000, length = 1) @BinaryField(index = 2000, length = 1)
private int flagDup; private int flagDup;
@ -22,9 +21,6 @@ public class MqttPacketPublish extends MqttPacketHeader {
private int flagQoS; private int flagQoS;
@BinaryField(index = 2002, length = 1) @BinaryField(index = 2002, length = 1)
private int flagRetain; private int flagRetain;
@CustomBinaryField(index = 3, serializer = MqttVariableIntSerializer.class)
private int length;
*/ */
// Variable Header // Variable Header

View file

@ -3,67 +3,172 @@ package zutil.net.mqtt;
import org.junit.Test; import org.junit.Test;
import zutil.net.mqtt.MqttBroker.MqttConnectionThread; import zutil.net.mqtt.MqttBroker.MqttConnectionThread;
import zutil.net.mqtt.packet.*; import zutil.net.mqtt.packet.*;
import zutil.net.mqtt.packet.MqttPacketSubscribe.MqttSubscribePayload;
import zutil.net.mqtt.packet.MqttPacketUnsubscribe.MqttUnsubscribePayload;
import java.io.IOException; import java.io.IOException;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Queue;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class MqttBrokerTest { public class MqttBrokerTest {
public static class MqttConnectionMockThread extends MqttConnectionThread { //**************** Mocks **************************
public static class MqttConnectionThreadMock extends MqttConnectionThread {
public LinkedList<MqttPacketHeader> sentPackets = new LinkedList<>(); public LinkedList<MqttPacketHeader> sentPackets = new LinkedList<>();
protected MqttConnectionThreadMock(MqttBroker b) {
super(b);
}
@Override @Override
public void sendPacket(MqttPacketHeader packet){ public void sendPacket(MqttPacketHeader packet) {
sentPackets.add(packet); sentPackets.add(packet);
} }
} }
//**************** Test Cases **************************
@Test @Test
public void subscribeEmpty() throws IOException { public void subscribeEmpty() throws IOException {
MqttConnectionMockThread thread = new MqttConnectionMockThread(); MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker());
MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe(); MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe();
subscribePacket.packetId = (int)(Math.random()*1000); subscribePacket.packetId = (int)(Math.random()*1000);
thread.handlePacket(subscribePacket); thread.handlePacket(subscribePacket);
// Check response
MqttPacketHeader responsePacket = thread.sentPackets.poll(); MqttPacketHeader responsePacket = thread.sentPackets.poll();
assertEquals(MqttPacketSubscribeAck.class, responsePacket.getClass()); assertEquals(MqttPacketSubscribeAck.class, responsePacket.getClass());
assertEquals(subscribePacket.packetId, ((MqttPacketSubscribeAck)responsePacket).packetId); assertEquals(subscribePacket.packetId, ((MqttPacketSubscribeAck)responsePacket).packetId);
assertEquals(subscribePacket.payload.size(), ((MqttPacketSubscribeAck)responsePacket).payload.size());
} }
@Test @Test
public void unsubscribe() throws IOException { public void subscribe() throws IOException {
MqttConnectionMockThread thread = new MqttConnectionMockThread(); MqttBroker broker = new MqttBroker();
MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker);
MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe();
subscribePacket.packetId = (int)(Math.random()*1000);
subscribePacket.payload.add(new MqttSubscribePayload());
subscribePacket.payload.get(0).topicFilter = "topic1";
subscribePacket.payload.add(new MqttSubscribePayload());
subscribePacket.payload.get(1).topicFilter = "topic2";
thread.handlePacket(subscribePacket);
// Check response
MqttPacketHeader responsePacket = thread.sentPackets.poll();
assertEquals(MqttPacketSubscribeAck.class, responsePacket.getClass());
assertEquals(subscribePacket.packetId, ((MqttPacketSubscribeAck)responsePacket).packetId);
assertEquals(subscribePacket.payload.size(), ((MqttPacketSubscribeAck)responsePacket).payload.size());
// Check broker
assertEquals(1, broker.getSubscriberCount("topic1"));
assertEquals(1, broker.getSubscriberCount("topic2"));
//************************ Duplicate subscribe packet
subscribePacket.packetId = (int)(Math.random()*1000);
subscribePacket.payload.clear();
subscribePacket.payload.add(new MqttSubscribePayload());
subscribePacket.payload.get(0).topicFilter = "topic1";
thread.handlePacket(subscribePacket);
// Check broker
assertEquals(1, broker.getSubscriberCount("topic1"));
//************************ New subscriber
MqttConnectionThreadMock thread2 = new MqttConnectionThreadMock(broker);
thread2.handlePacket(subscribePacket);
// Check broker
assertEquals(2, broker.getSubscriberCount("topic1"));
assertEquals(1, broker.getSubscriberCount("topic2"));
}
@Test
public void unsubscribeEmpty() throws IOException {
MqttBroker broker = new MqttBroker();
MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker);
MqttPacketUnsubscribe unsubscribePacket = new MqttPacketUnsubscribe(); MqttPacketUnsubscribe unsubscribePacket = new MqttPacketUnsubscribe();
unsubscribePacket.packetId = (int)(Math.random()*1000); unsubscribePacket.packetId = (int)(Math.random()*1000);
thread.handlePacket(unsubscribePacket); thread.handlePacket(unsubscribePacket);
// Check response
MqttPacketHeader responsePacket = thread.sentPackets.poll(); MqttPacketHeader responsePacket = thread.sentPackets.poll();
assertEquals(MqttPacketUnsubscribeAck.class, responsePacket.getClass()); assertEquals(MqttPacketUnsubscribeAck.class, responsePacket.getClass());
assertEquals(unsubscribePacket.packetId, ((MqttPacketUnsubscribeAck)responsePacket).packetId); assertEquals(unsubscribePacket.packetId, ((MqttPacketUnsubscribeAck)responsePacket).packetId);
} }
@Test
public void unsubscribe() throws IOException {
MqttBroker broker = new MqttBroker();
MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker);
MqttPacketUnsubscribe unsubscribePacket = new MqttPacketUnsubscribe();
unsubscribePacket.packetId = (int)(Math.random()*1000);
unsubscribePacket.payload.add(new MqttUnsubscribePayload());
unsubscribePacket.payload.get(0).topicFilter = "topic1";
thread.handlePacket(unsubscribePacket);
// Check response
MqttPacketHeader responsePacket = thread.sentPackets.poll();
assertEquals(MqttPacketUnsubscribeAck.class, responsePacket.getClass());
assertEquals(unsubscribePacket.packetId, ((MqttPacketUnsubscribeAck)responsePacket).packetId);
// Check broker
assertEquals(-1, broker.getSubscriberCount("topic1"));
//************************ New subscriber
MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe();
subscribePacket.packetId = (int)(Math.random()*1000);
subscribePacket.payload.add(new MqttSubscribePayload());
subscribePacket.payload.get(0).topicFilter = "topic1";
subscribePacket.payload.add(new MqttSubscribePayload());
subscribePacket.payload.get(1).topicFilter = "topic2";
thread.handlePacket(subscribePacket);
// Check broker
assertEquals(1, broker.getSubscriberCount("topic1"));
//************************ Unsubscribe
unsubscribePacket.packetId = (int)(Math.random()*1000);
thread.handlePacket(unsubscribePacket);
// Check broker
assertEquals(-1, broker.getSubscriberCount("topic1"));
}
@Test @Test
public void ping() throws IOException { public void ping() throws IOException {
MqttConnectionMockThread thread = new MqttConnectionMockThread(); MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker());
MqttPacketPingReq pingPacket = new MqttPacketPingReq(); MqttPacketPingReq pingPacket = new MqttPacketPingReq();
thread.handlePacket(pingPacket); thread.handlePacket(pingPacket);
// Check response
assertEquals(MqttPacketPingResp.class, thread.sentPackets.poll().getClass()); assertEquals(MqttPacketPingResp.class, thread.sentPackets.poll().getClass());
} }
@Test @Test
public void disconnect() throws IOException { public void disconnect() throws IOException {
MqttConnectionMockThread thread = new MqttConnectionMockThread(); MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker());
MqttPacketDisconnect disconnectPacket = new MqttPacketDisconnect(); MqttPacketDisconnect disconnectPacket = new MqttPacketDisconnect();
thread.handlePacket(disconnectPacket); thread.handlePacket(disconnectPacket);
// Check response
assertEquals(null, thread.sentPackets.poll()); assertEquals(null, thread.sentPackets.poll());
assertTrue(thread.isShutdown()); assertTrue(thread.isShutdown());
} }