From 5ecfee9dfd60573a7d869af44d641addff1e775f Mon Sep 17 00:00:00 2001 From: Ziver Koc Date: Mon, 12 Nov 2018 16:24:05 +0100 Subject: [PATCH] Added subscribe and unsubscribe support and TCs --- src/zutil/net/mqtt/MqttBroker.java | 137 +++++++++++------- .../net/mqtt/MqttSubscriptionListener.java | 2 +- src/zutil/net/mqtt/packet/MqttPacket.java | 2 +- .../net/mqtt/packet/MqttPacketPublish.java | 4 - test/zutil/net/mqtt/MqttBrokerTest.java | 121 +++++++++++++++- 5 files changed, 201 insertions(+), 65 deletions(-) diff --git a/src/zutil/net/mqtt/MqttBroker.java b/src/zutil/net/mqtt/MqttBroker.java index ee5ace4..140d419 100755 --- a/src/zutil/net/mqtt/MqttBroker.java +++ b/src/zutil/net/mqtt/MqttBroker.java @@ -1,5 +1,6 @@ package zutil.net.mqtt; +import zutil.ObjectUtil; import zutil.log.LogUtil; import zutil.net.mqtt.packet.*; import zutil.net.mqtt.packet.MqttPacketSubscribe.MqttSubscribePayload; @@ -37,7 +38,20 @@ public class MqttBroker extends ThreadedTCPNetworkServer { @Override protected ThreadedTCPNetworkServerThread getThreadInstance(Socket s) throws IOException { - return new MqttConnectionThread(s); + return new MqttConnectionThread(this, s); + } + + + /** + * @return the subscriber count for the specific topic, -1 if + * topic does not exist or has not been created yet. + */ + public int getSubscriberCount(String topic) { + List topicSubscriptions = subscriptions.get(topic); + if (topicSubscriptions != null) { + return topicSubscriptions.size(); + } + return -1; } @@ -51,7 +65,7 @@ public class MqttBroker extends ThreadedTCPNetworkServer { } List topicSubscriptions = subscriptions.get(topic); - if (topicSubscriptions.contains(listener)) { + if (!topicSubscriptions.contains(listener)) { logger.finer("New subscriber on topic (" + topic + "), subscriber count: " + topicSubscriptions.size()); topicSubscriptions.add(listener); } @@ -85,16 +99,22 @@ public class MqttBroker extends ThreadedTCPNetworkServer { protected static class MqttConnectionThread implements ThreadedTCPNetworkServerThread, MqttSubscriptionListener { + private MqttBroker broker; private Socket socket; private BinaryStructInputStream in; private BinaryStructOutputStream out; private boolean shutdown = false; + /** + * Test constructor + */ + protected MqttConnectionThread(MqttBroker b) { + broker = b; + } - protected MqttConnectionThread() {} // Test constructor - - public MqttConnectionThread(Socket s) throws IOException { + public MqttConnectionThread(MqttBroker b, Socket s) throws IOException { + this(b); socket = s; in = new BinaryStructInputStream(socket.getInputStream()); out = new BinaryStructOutputStream(socket.getOutputStream()); @@ -106,26 +126,7 @@ public class MqttBroker extends ThreadedTCPNetworkServer { try { // Setup connection MqttPacketHeader connectPacket = MqttPacket.read(in); - // Unexpected packet? - if (!(connectPacket instanceof MqttPacketConnect)) - throw new IOException("Expected MqttPacketConnect but received " + connectPacket.getClass()); - MqttPacketConnect conn = (MqttPacketConnect) connectPacket; - - // Reply - MqttPacketConnectAck connectAck = new MqttPacketConnectAck(); - - // Incorrect protocol version? - if (conn.protocolLevel != MQTT_PROTOCOL_VERSION) { - connectAck.returnCode = MqttPacketConnectAck.RETCODE_PROT_VER_ERROR; - sendPacket(connectAck); - return; - } else { - connectAck.returnCode = MqttPacketConnectAck.RETCODE_OK; - } - - // TODO: authenticate - // TODO: clean session - sendPacket(connectAck); + handleConnect(connectPacket); // Connected @@ -137,53 +138,32 @@ public class MqttBroker extends ThreadedTCPNetworkServer { handlePacket(packet); } - socket.close(); } catch (IOException e) { logger.log(Level.SEVERE, null, e); } finally { try { socket.close(); + broker.unsubscribe(this); } catch (IOException e) { logger.log(Level.SEVERE, null, e); } } } - public void handlePacket(MqttPacketHeader packet) throws IOException { + protected void handlePacket(MqttPacketHeader packet) throws IOException { // TODO: QOS switch (packet.type) { - // TODO: Publish case MqttPacketHeader.PACKET_TYPE_PUBLISH: + handlePublish((MqttPacketPublish) packet); break; - // TODO: Subscribe case MqttPacketHeader.PACKET_TYPE_SUBSCRIBE: - MqttPacketSubscribe subscribePacket = (MqttPacketSubscribe) packet; - MqttPacketSubscribeAck subscribeAckPacket = new MqttPacketSubscribeAck(); - subscribeAckPacket.packetId = subscribePacket.packetId; - - for (MqttSubscribePayload payload : subscribePacket.payload) { - // TODO: subscribe(payload.topicFilter, this) - - MqttSubscribeAckPayload ackPayload = new MqttSubscribeAckPayload(); - ackPayload.returnCode = MqttSubscribeAckPayload.RETCODE_SUCESS_MAX_QOS_0; - subscribeAckPacket.payload.add(ackPayload); - } - sendPacket(subscribeAckPacket); + handleSubscribe((MqttPacketSubscribe) packet); break; - // TODO: Unsubscribe case MqttPacketHeader.PACKET_TYPE_UNSUBSCRIBE: - MqttPacketUnsubscribe unsubscribePacket = (MqttPacketUnsubscribe) packet; - - for (MqttUnsubscribePayload payload : unsubscribePacket.payload) { - // TODO: unsubscribe(payload.topicFilter, this) - } - - MqttPacketUnsubscribeAck unsubscribeAckPacket = new MqttPacketUnsubscribeAck(); - unsubscribeAckPacket.packetId = unsubscribePacket.packetId; - sendPacket(unsubscribeAckPacket); + handleUnsubscribe((MqttPacketUnsubscribe) packet); break; // Ping @@ -200,6 +180,61 @@ public class MqttBroker extends ThreadedTCPNetworkServer { } } + private void handleConnect(MqttPacketHeader connectPacket) throws IOException { + // Unexpected packet? + if (!(connectPacket instanceof MqttPacketConnect)) + throw new IOException("Expected MqttPacketConnect but received " + connectPacket.getClass()); + MqttPacketConnect conn = (MqttPacketConnect) connectPacket; + + // Reply + MqttPacketConnectAck connectAck = new MqttPacketConnectAck(); + + // Incorrect protocol version? + if (conn.protocolLevel != MQTT_PROTOCOL_VERSION) { + connectAck.returnCode = MqttPacketConnectAck.RETCODE_PROT_VER_ERROR; + sendPacket(connectAck); + return; + } else { + connectAck.returnCode = MqttPacketConnectAck.RETCODE_OK; + } + + // TODO: authenticate + // TODO: clean session + sendPacket(connectAck); + } + + private void handlePublish(MqttPacketPublish publishPacket) throws IOException { + // TODO: Publish + } + + private void handleSubscribe(MqttPacketSubscribe subscribePacket) throws IOException { + MqttPacketSubscribeAck subscribeAckPacket = new MqttPacketSubscribeAck(); + subscribeAckPacket.packetId = subscribePacket.packetId; + + for (MqttSubscribePayload payload : subscribePacket.payload) { + broker.subscribe(payload.topicFilter, this); + + // Prepare response + MqttSubscribeAckPayload ackPayload = new MqttSubscribeAckPayload(); + ackPayload.returnCode = MqttSubscribeAckPayload.RETCODE_SUCESS_MAX_QOS_0; + subscribeAckPacket.payload.add(ackPayload); + } + + sendPacket(subscribeAckPacket); + } + + private void handleUnsubscribe(MqttPacketUnsubscribe unsubscribePacket) throws IOException { + for (MqttUnsubscribePayload payload : unsubscribePacket.payload) { + broker.unsubscribe(payload.topicFilter, this); + } + + // Prepare response + MqttPacketUnsubscribeAck unsubscribeAckPacket = new MqttPacketUnsubscribeAck(); + unsubscribeAckPacket.packetId = unsubscribePacket.packetId; + sendPacket(unsubscribeAckPacket); + } + + @Override public void dataPublished(String topic, String data) { diff --git a/src/zutil/net/mqtt/MqttSubscriptionListener.java b/src/zutil/net/mqtt/MqttSubscriptionListener.java index f3f1f17..392261a 100644 --- a/src/zutil/net/mqtt/MqttSubscriptionListener.java +++ b/src/zutil/net/mqtt/MqttSubscriptionListener.java @@ -30,5 +30,5 @@ package zutil.net.mqtt; */ public interface MqttSubscriptionListener { - public void dataPublished(String topic, String data); + void dataPublished(String topic, String data); } diff --git a/src/zutil/net/mqtt/packet/MqttPacket.java b/src/zutil/net/mqtt/packet/MqttPacket.java index 4de1a62..3fafedf 100755 --- a/src/zutil/net/mqtt/packet/MqttPacket.java +++ b/src/zutil/net/mqtt/packet/MqttPacket.java @@ -9,7 +9,7 @@ import java.io.IOException; import static zutil.net.mqtt.packet.MqttPacketHeader.*; /** - * A data class encapsulating a MQTT header and its controlHeader + * A class for serializing and deserialize MQTT data packets */ public class MqttPacket { diff --git a/src/zutil/net/mqtt/packet/MqttPacketPublish.java b/src/zutil/net/mqtt/packet/MqttPacketPublish.java index fb27a63..7e34fe2 100755 --- a/src/zutil/net/mqtt/packet/MqttPacketPublish.java +++ b/src/zutil/net/mqtt/packet/MqttPacketPublish.java @@ -14,7 +14,6 @@ public class MqttPacketPublish extends MqttPacketHeader { type = MqttPacketHeader.PACKET_TYPE_PUBLISH; } - // Static Header /* @BinaryField(index = 2000, length = 1) private int flagDup; @@ -22,9 +21,6 @@ public class MqttPacketPublish extends MqttPacketHeader { private int flagQoS; @BinaryField(index = 2002, length = 1) private int flagRetain; - - @CustomBinaryField(index = 3, serializer = MqttVariableIntSerializer.class) - private int length; */ // Variable Header diff --git a/test/zutil/net/mqtt/MqttBrokerTest.java b/test/zutil/net/mqtt/MqttBrokerTest.java index e3ddadf..97703f3 100644 --- a/test/zutil/net/mqtt/MqttBrokerTest.java +++ b/test/zutil/net/mqtt/MqttBrokerTest.java @@ -3,67 +3,172 @@ package zutil.net.mqtt; import org.junit.Test; import zutil.net.mqtt.MqttBroker.MqttConnectionThread; import zutil.net.mqtt.packet.*; +import zutil.net.mqtt.packet.MqttPacketSubscribe.MqttSubscribePayload; +import zutil.net.mqtt.packet.MqttPacketUnsubscribe.MqttUnsubscribePayload; import java.io.IOException; import java.util.LinkedList; -import java.util.Queue; import static org.junit.Assert.*; public class MqttBrokerTest { - public static class MqttConnectionMockThread extends MqttConnectionThread { + //**************** Mocks ************************** + + public static class MqttConnectionThreadMock extends MqttConnectionThread { public LinkedList sentPackets = new LinkedList<>(); + protected MqttConnectionThreadMock(MqttBroker b) { + super(b); + } + @Override - public void sendPacket(MqttPacketHeader packet){ + public void sendPacket(MqttPacketHeader packet) { sentPackets.add(packet); } } + //**************** Test Cases ************************** + @Test public void subscribeEmpty() throws IOException { - MqttConnectionMockThread thread = new MqttConnectionMockThread(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker()); MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe(); subscribePacket.packetId = (int)(Math.random()*1000); thread.handlePacket(subscribePacket); + // Check response MqttPacketHeader responsePacket = thread.sentPackets.poll(); assertEquals(MqttPacketSubscribeAck.class, responsePacket.getClass()); assertEquals(subscribePacket.packetId, ((MqttPacketSubscribeAck)responsePacket).packetId); + assertEquals(subscribePacket.payload.size(), ((MqttPacketSubscribeAck)responsePacket).payload.size()); } @Test - public void unsubscribe() throws IOException { - MqttConnectionMockThread thread = new MqttConnectionMockThread(); + public void subscribe() throws IOException { + MqttBroker broker = new MqttBroker(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker); + MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe(); + subscribePacket.packetId = (int)(Math.random()*1000); + + subscribePacket.payload.add(new MqttSubscribePayload()); + subscribePacket.payload.get(0).topicFilter = "topic1"; + subscribePacket.payload.add(new MqttSubscribePayload()); + subscribePacket.payload.get(1).topicFilter = "topic2"; + + thread.handlePacket(subscribePacket); + + // Check response + MqttPacketHeader responsePacket = thread.sentPackets.poll(); + assertEquals(MqttPacketSubscribeAck.class, responsePacket.getClass()); + assertEquals(subscribePacket.packetId, ((MqttPacketSubscribeAck)responsePacket).packetId); + assertEquals(subscribePacket.payload.size(), ((MqttPacketSubscribeAck)responsePacket).payload.size()); + // Check broker + assertEquals(1, broker.getSubscriberCount("topic1")); + assertEquals(1, broker.getSubscriberCount("topic2")); + + //************************ Duplicate subscribe packet + + subscribePacket.packetId = (int)(Math.random()*1000); + subscribePacket.payload.clear(); + subscribePacket.payload.add(new MqttSubscribePayload()); + subscribePacket.payload.get(0).topicFilter = "topic1"; + + thread.handlePacket(subscribePacket); + + // Check broker + assertEquals(1, broker.getSubscriberCount("topic1")); + + //************************ New subscriber + + MqttConnectionThreadMock thread2 = new MqttConnectionThreadMock(broker); + + thread2.handlePacket(subscribePacket); + + // Check broker + assertEquals(2, broker.getSubscriberCount("topic1")); + assertEquals(1, broker.getSubscriberCount("topic2")); + } + + @Test + public void unsubscribeEmpty() throws IOException { + MqttBroker broker = new MqttBroker(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker); MqttPacketUnsubscribe unsubscribePacket = new MqttPacketUnsubscribe(); unsubscribePacket.packetId = (int)(Math.random()*1000); thread.handlePacket(unsubscribePacket); + // Check response MqttPacketHeader responsePacket = thread.sentPackets.poll(); assertEquals(MqttPacketUnsubscribeAck.class, responsePacket.getClass()); assertEquals(unsubscribePacket.packetId, ((MqttPacketUnsubscribeAck)responsePacket).packetId); } + @Test + public void unsubscribe() throws IOException { + MqttBroker broker = new MqttBroker(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(broker); + MqttPacketUnsubscribe unsubscribePacket = new MqttPacketUnsubscribe(); + unsubscribePacket.packetId = (int)(Math.random()*1000); + + unsubscribePacket.payload.add(new MqttUnsubscribePayload()); + unsubscribePacket.payload.get(0).topicFilter = "topic1"; + + thread.handlePacket(unsubscribePacket); + + // Check response + MqttPacketHeader responsePacket = thread.sentPackets.poll(); + assertEquals(MqttPacketUnsubscribeAck.class, responsePacket.getClass()); + assertEquals(unsubscribePacket.packetId, ((MqttPacketUnsubscribeAck)responsePacket).packetId); + // Check broker + assertEquals(-1, broker.getSubscriberCount("topic1")); + + //************************ New subscriber + + MqttPacketSubscribe subscribePacket = new MqttPacketSubscribe(); + subscribePacket.packetId = (int)(Math.random()*1000); + + subscribePacket.payload.add(new MqttSubscribePayload()); + subscribePacket.payload.get(0).topicFilter = "topic1"; + subscribePacket.payload.add(new MqttSubscribePayload()); + subscribePacket.payload.get(1).topicFilter = "topic2"; + + thread.handlePacket(subscribePacket); + + // Check broker + assertEquals(1, broker.getSubscriberCount("topic1")); + + //************************ Unsubscribe + + unsubscribePacket.packetId = (int)(Math.random()*1000); + + thread.handlePacket(unsubscribePacket); + + // Check broker + assertEquals(-1, broker.getSubscriberCount("topic1")); + } + @Test public void ping() throws IOException { - MqttConnectionMockThread thread = new MqttConnectionMockThread(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker()); MqttPacketPingReq pingPacket = new MqttPacketPingReq(); thread.handlePacket(pingPacket); + // Check response assertEquals(MqttPacketPingResp.class, thread.sentPackets.poll().getClass()); } @Test public void disconnect() throws IOException { - MqttConnectionMockThread thread = new MqttConnectionMockThread(); + MqttConnectionThreadMock thread = new MqttConnectionThreadMock(new MqttBroker()); MqttPacketDisconnect disconnectPacket = new MqttPacketDisconnect(); thread.handlePacket(disconnectPacket); + // Check response assertEquals(null, thread.sentPackets.poll()); assertTrue(thread.isShutdown()); }