Add test cases for local ID assignment race conditions
authorMitja Nikolaus <mitja@fairphone.com>
Tue, 13 Nov 2018 15:48:20 +0000 (16:48 +0100)
committerMitja Nikolaus <mitja@fairphone.com>
Fri, 7 Dec 2018 12:54:03 +0000 (12:54 +0000)
Issue: HIC-251
Change-Id: I9616b786a44bba2621a8a800035ff1189d3d9b19

crashreports/tests/test_rest_api_crashreports.py
crashreports/tests/test_rest_api_devices.py
crashreports/tests/test_rest_api_heartbeats.py
crashreports/tests/test_rest_api_logfiles.py
crashreports/tests/utils.py

index dbfaccc..b09ca1d 100644 (file)
@@ -1,8 +1,13 @@
 """Tests for the crashreports REST API."""
+import unittest
+from datetime import timedelta
+
+from django.db import connection
 from django.urls import reverse
 from rest_framework import status
 
-from crashreports.tests.utils import Dummy
+from crashreports.models import Crashreport
+from crashreports.tests.utils import Dummy, RaceConditionsTestCase
 from crashreports.tests.test_rest_api_heartbeats import HeartbeatsTestCase
 
 
@@ -47,3 +52,29 @@ class CrashreportsTestCase(HeartbeatsTestCase):
     def test_create_with_datetime(self):
         """Override to just pass because crashreports always use datetime."""
         pass
+
+
+@unittest.skip("Fails because of race condition when assigning local IDs")
+class CrashreportRaceConditionsTestCase(RaceConditionsTestCase):
+    """Test cases for crashreport race conditions."""
+
+    LIST_CREATE_URL = "api_v1_crashreports"
+
+    def test_create_multiple_crashreports(self):
+        """Test that no race condition occurs when creating crashreports."""
+        uuid, user, _ = self._register_device()
+
+        def upload_report(client, data):
+            response = client.post(reverse(self.LIST_CREATE_URL), data)
+            self.assertEqual(status.HTTP_201_CREATED, response.status_code)
+            connection.close()
+
+        data = Dummy.crashreport_data(uuid=uuid)
+        argslist = [
+            [user, dict(data, date=data["date"] + timedelta(milliseconds=i))]
+            for i in range(10)
+        ]
+
+        self._test_create_multiple(
+            Crashreport, upload_report, argslist, "device_local_id"
+        )
index 512a56a..bc4edea 100644 (file)
@@ -4,12 +4,14 @@ from django.urls import reverse
 
 from rest_framework import status
 
-from crashreports.tests.utils import HiccupCrashreportsAPITestCase, Dummy
+from crashreports.tests.utils import Dummy, HiccupCrashreportsAPITestCase
 
 
 class DeviceTestCase(HiccupCrashreportsAPITestCase):
     """Test cases for registering devices."""
 
+    # pylint: disable=too-many-ancestors
+
     def test_register(self):
         """Test registration of devices."""
         response = self.client.post(
@@ -79,6 +81,8 @@ class DeviceTestCase(HiccupCrashreportsAPITestCase):
 class ListDevicesTestCase(HiccupCrashreportsAPITestCase):
     """Test cases for listing and deleting devices."""
 
+    # pylint: disable=too-many-ancestors
+
     LIST_CREATE_URL = "api_v1_list_devices"
     RETRIEVE_URL = "api_v1_retrieve_device"
 
index 9ae6b29..d2c3e1a 100644 (file)
@@ -1,19 +1,26 @@
 """Tests for the heartbeats REST API."""
 from datetime import timedelta, datetime
+import unittest
 
 import pytz
+from django.db import connection
 from django.urls import reverse
 
 from rest_framework import status
 from rest_framework.test import APIClient
 
-from crashreports.tests.utils import HiccupCrashreportsAPITestCase, Dummy
+from crashreports.tests.utils import (
+    Dummy,
+    RaceConditionsTestCase,
+    HiccupCrashreportsAPITestCase,
+)
+from crashreports.models import HeartBeat
 
 
 class HeartbeatsTestCase(HiccupCrashreportsAPITestCase):
     """Test cases for heartbeats."""
 
-    # pylint: disable=too-many-public-methods
+    # pylint: disable=too-many-public-methods,too-many-ancestors
 
     LIST_CREATE_URL = "api_v1_heartbeats"
     RETRIEVE_URL = "api_v1_heartbeat"
@@ -313,3 +320,29 @@ class HeartbeatsTestCase(HiccupCrashreportsAPITestCase):
         response = self.user.post(reverse(self.LIST_CREATE_URL), data)
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(response.data["date"], str(data["date"].date()))
+
+
+@unittest.skip("Fails because of race condition when assigning local IDs")
+class HeartBeatRaceConditionsTestCase(RaceConditionsTestCase):
+    """Test cases for heartbeat race conditions."""
+
+    LIST_CREATE_URL = "api_v1_heartbeats"
+
+    def test_create_multiple_heartbeats(self):
+        """Test that no race condition occurs when creating heartbeats."""
+        uuid, user, _ = self._register_device()
+
+        def upload_report(client, data):
+            response = client.post(reverse(self.LIST_CREATE_URL), data)
+            self.assertEqual(status.HTTP_201_CREATED, response.status_code)
+            connection.close()
+
+        data = Dummy.heartbeat_data(uuid=uuid)
+        argslist = [
+            [user, dict(data, date=data["date"] + timedelta(days=i))]
+            for i in range(10)
+        ]
+
+        self._test_create_multiple(
+            HeartBeat, upload_report, argslist, "device_local_id"
+        )
index 57482bf..b0aeb34 100644 (file)
@@ -3,10 +3,12 @@
 import os
 import shutil
 import tempfile
+import unittest
 import zipfile
 
 from django.conf import settings
 from django.core.files.storage import default_storage
+from django.db import connection
 from django.test import override_settings
 from django.urls import reverse
 
@@ -18,13 +20,22 @@ from crashreports.models import (
     Crashreport,
     LogFile,
 )
-from crashreports.tests.utils import HiccupCrashreportsAPITestCase, Dummy
+from crashreports.tests.utils import (
+    Dummy,
+    RaceConditionsTestCase,
+    HiccupCrashreportsAPITestCase,
+)
+
+LIST_CREATE_URL = "api_v1_crashreports"
+PUT_LOGFILE_URL = "api_v1_putlogfile_for_device_id"
 
 
 @override_settings(MEDIA_ROOT=tempfile.mkdtemp(".hiccup-tests"))
 class LogfileUploadTest(HiccupCrashreportsAPITestCase):
     """Test cases for upload of log files."""
 
+    # pylint: disable=too-many-ancestors
+
     LIST_CREATE_URL = "api_v1_crashreports"
     PUT_LOGFILE_URL = "api_v1_putlogfile_for_device_id"
     POST_LOGFILE_URL = "api_v1_logfiles_by_id"
@@ -34,7 +45,7 @@ class LogfileUploadTest(HiccupCrashreportsAPITestCase):
         super().setUp()
         self.device_uuid, self.user, _ = self._register_device()
 
-    def _upload_crashreport(self, user, uuid):
+    def upload_crashreport(self, user, uuid):
         """
         Upload dummy crashreport data.
 
@@ -46,7 +57,7 @@ class LogfileUploadTest(HiccupCrashreportsAPITestCase):
 
         """
         data = Dummy.crashreport_data(uuid=uuid)
-        response = user.post(reverse(self.LIST_CREATE_URL), data)
+        response = user.post(reverse(LIST_CREATE_URL), data)
         self.assertEqual(status.HTTP_201_CREATED, response.status_code)
         self.assertTrue("device_local_id" in response.data)
         device_local_id = response.data["device_local_id"]
@@ -65,23 +76,27 @@ class LogfileUploadTest(HiccupCrashreportsAPITestCase):
 
             self.assertEqual(file_1.read(), file_2.read())
 
-    def _test_logfile_upload(self, user, uuid):
-        # Upload crashreport
-        device_local_id = self._upload_crashreport(user, uuid)
-
-        # Upload a logfile for the crashreport
+    def upload_logfile(self, client, uuid, device_local_id):
+        """Upload a log file and assert that it was created."""
         logfile = open(Dummy.DEFAULT_DUMMY_LOG_FILE_PATHS[0], "rb")
-
         logfile_name = os.path.basename(logfile.name)
-        response = user.post(
+        response = client.post(
             reverse(
-                self.PUT_LOGFILE_URL, args=[uuid, device_local_id, logfile_name]
+                PUT_LOGFILE_URL, args=[uuid, device_local_id, logfile_name]
             ),
             {"file": logfile},
             format="multipart",
         )
         logfile.close()
         self.assertEqual(status.HTTP_201_CREATED, response.status_code)
+        return response
+
+    def _test_logfile_upload(self, user, uuid):
+        # Upload crashreport
+        device_local_id = self.upload_crashreport(user, uuid)
+
+        # Upload a logfile for the crashreport
+        self.upload_logfile(user, uuid, device_local_id)
 
         logfile_instance = (
             Device.objects.get(uuid=uuid)
@@ -89,7 +104,8 @@ class LogfileUploadTest(HiccupCrashreportsAPITestCase):
             .logfiles.last()
         )
         uploaded_logfile_path = crashreport_file_name(
-            logfile_instance, logfile_name
+            logfile_instance,
+            os.path.basename(Dummy.DEFAULT_DUMMY_LOG_FILE_PATHS[0]),
         )
 
         self.assertTrue(default_storage.exists(uploaded_logfile_path))
@@ -135,3 +151,25 @@ class LogfileUploadTest(HiccupCrashreportsAPITestCase):
     def tearDown(self):
         """Remove the file and directories that were created for the test."""
         shutil.rmtree(settings.MEDIA_ROOT)
+
+
+@unittest.skip("Fails because of race condition when assigning local IDs")
+class LogfileRaceConditionsTestCase(RaceConditionsTestCase):
+    """Test cases for logfile race conditions."""
+
+    def test_create_multiple_logfiles(self):
+        """Test that no race condition occurs when creating logfiles."""
+        uuid, user, _ = self._register_device()
+        device_local_id = LogfileUploadTest.upload_crashreport(self, user, uuid)
+
+        def upload_logfile(client, uuid, device_local_id):
+            LogfileUploadTest.upload_logfile(
+                self, client, uuid, device_local_id
+            )
+            connection.close()
+
+        argslist = [[user, uuid, device_local_id] for _ in range(10)]
+
+        self._test_create_multiple(
+            LogFile, upload_logfile, argslist, "crashreport_local_id"
+        )
index 73a479b..e9139df 100644 (file)
@@ -2,6 +2,7 @@
 
 import os
 import shutil
+import threading
 import zipfile
 from datetime import date, datetime
 from typing import Optional
@@ -9,9 +10,10 @@ from typing import Optional
 import pytz
 from django.conf import settings
 from django.contrib.auth.models import User, Group
+from django.test import TransactionTestCase
 from django.urls import reverse
 from rest_framework import status
-from rest_framework.test import APITestCase, APIClient
+from rest_framework.test import APIClient, APITestCase
 
 from crashreports.models import (
     Crashreport,
@@ -370,7 +372,7 @@ class Dummy:
         return archive.read(logfile_name)
 
 
-class HiccupCrashreportsAPITestCase(APITestCase):
+class HiccupCrashreportsTransactionTestCase(TransactionTestCase):
     """Base class that offers a device registration method."""
 
     REGISTER_DEVICE_URL = "api_v1_register_device"
@@ -410,3 +412,40 @@ class HiccupCrashreportsAPITestCase(APITestCase):
         user.credentials(HTTP_AUTHORIZATION="Token " + token)
 
         return uuid, user, token
+
+
+class HiccupCrashreportsAPITestCase(
+    HiccupCrashreportsTransactionTestCase, APITestCase
+):
+    """Base class combining device registration methods and API test methods."""
+
+    pass
+
+
+class RaceConditionsTestCase(HiccupCrashreportsTransactionTestCase):
+    """Test cases for race conditions."""
+
+    # Make data from migrations available in the test cases
+    serialized_rollback = True
+
+    def _test_create_multiple(
+        self, report_type, create_function, argslist, local_id_name
+    ):
+        """Test that no race condition occurs when creating instances."""
+        # Create multiple threads which send reports simultaneously
+        threads = []
+        for args in argslist:
+            thread = threading.Thread(target=create_function, args=args)
+            threads.append(thread)
+            thread.start()
+
+        # Wait until the threads have finished
+        for thread in threads:
+            thread.join()
+
+        # Assert that no duplicate local IDs have been assigned
+        reports = report_type.objects.all()
+        self.assertEqual(
+            reports.count(), reports.distinct(local_id_name).count()
+        )
+        self.assertEqual(reports.count(), len(argslist))