From b934e181c04a1deb756102898c6bf43d9c3804ca Mon Sep 17 00:00:00 2001
From: Matthew Hague <matthew.hague@rhul.ac.uk>
Date: Sun, 28 Nov 2021 17:46:50 +0000
Subject: [PATCH] Breaking Change: specify call return types

The method call functions in CodeTester now expect an argument giving
the return type. Can use "null" if you don't care.
---
 .../uk/ac/rhul/cs/javatester/CodeTester.java  | 156 +++++++++---------
 .../uk/ac/rhul/cs/javatester/UnitTests.java   |  37 +++--
 2 files changed, 109 insertions(+), 84 deletions(-)

diff --git a/src/main/java/uk/ac/rhul/cs/javatester/CodeTester.java b/src/main/java/uk/ac/rhul/cs/javatester/CodeTester.java
index 1b81218..a0bfdc7 100644
--- a/src/main/java/uk/ac/rhul/cs/javatester/CodeTester.java
+++ b/src/main/java/uk/ac/rhul/cs/javatester/CodeTester.java
@@ -327,18 +327,14 @@ public class CodeTester {
     /**
      * Convenience for callMsgAcceptNull with acceptNull set to true
      */
-    public Object callMsg(String msg,
-                          String invocationMsg,
-                          Object o,
-                          String methodName,
-                          Object... params) throws
-                  BaseTester.FailedTestException {
-        return callMsgAcceptNull(msg,
-                                 invocationMsg,
-                                 true,
-                                 o,
-                                 methodName,
-                                 params);
+    public Object callMsg(
+        String msg, String invocationMsg,
+        Class<?> returnType, Object o, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
+        return callMsgAcceptNull(
+            msg, invocationMsg, true,
+            returnType, o, methodName, params
+        );
     }
 
     /**
@@ -354,25 +350,26 @@ public class CodeTester {
      * @param invocationMsg failure message prefix in case method call
      * throws exception (student code fails)
      * @param acceptNull whether the result of the call can be null
+     * @param returnType the expected return type of the method (or null
+     * if don't care)
      * @param o the object to call the method on
      * @param String name method name
      * @param parameterTypes the arguments (cannot be null else type
      * can't be inferred, with throw NPE)
      * @return the result object
      */
-    public Object callMsgAcceptNull(String msg,
-                                    String invocationMsg,
-                                    boolean acceptNull,
-                                    Object o,
-                                    String methodName,
-                                    Object... params) throws
-                  BaseTester.FailedTestException {
+    public Object callMsgAcceptNull(
+        String msg, String invocationMsg, boolean acceptNull,
+        Class<?> returnType, Object o, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
         Class<?>[] parameterTypes
             = Stream.of(params).map(p -> p.getClass())
                                .toArray(Class<?>[]::new);
         try {
             Method m = invasiveGetMethod(o, methodName, parameterTypes);
 
+            checkReturnType(invocationMsg, m, returnType);
+
             Object r = invokeMethod(invocationMsg, m, o, params);
 
             if (r == null && !acceptNull) {
@@ -400,27 +397,24 @@ public class CodeTester {
         throw new BaseTester.FailedTestException(expandMsg(msg));
     }
 
-
     /**
      * Convenience for callMsg with default messages.
      *
      * Will allow method to return null.
      */
-    public Object call(Object o,
-                       String methodName,
-                       Object... params) throws
-                  BaseTester.FailedTestException {
-        return callAcceptNull(true, o, methodName, params);
+    public Object call(
+        Class<?> returnType, Object o, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
+        return callAcceptNull(true, returnType, o, methodName, params);
     }
 
     /**
      * Convenience for callMsgAcceptNull with default messages.
      */
-    public Object callAcceptNull(boolean acceptNull,
-                                 Object o,
-                                 String methodName,
-                                 Object... params) throws
-                  BaseTester.FailedTestException {
+    public Object callAcceptNull(
+        boolean acceptNull,
+        Class<?> returnType, Object o, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
         String base;
 
         if (getLines().size() > 0) {
@@ -435,12 +429,10 @@ public class CodeTester {
         String msg = base + "could not be made.";
         String invokeMsg = base + "went wrong.";
 
-        return callMsgAcceptNull(msg,
-                                 invokeMsg,
-                                 acceptNull,
-                                 o,
-                                 methodName,
-                                 params);
+        return callMsgAcceptNull(
+            msg, invokeMsg, acceptNull,
+            returnType, o, methodName, params
+        );
     }
 
     /**
@@ -500,18 +492,14 @@ public class CodeTester {
     /**
      * Convenience for callStaticMsgAcceptNull with acceptNull as true
      */
-    public Object callStaticMsg(String msg,
-                                String invocationMsg,
-                                Class<?> klass,
-                                String methodName,
-                                Object... params) throws
-                  BaseTester.FailedTestException {
-        return callStaticMsgAcceptNull(msg,
-                                       invocationMsg,
-                                       true,
-                                       klass,
-                                       methodName,
-                                       params);
+    public Object callStaticMsg(
+        String msg, String invocationMsg,
+        Class<?> returnType, Class<?> klass, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
+        return callStaticMsgAcceptNull(
+            msg, invocationMsg, true,
+            returnType, klass, methodName, params
+        );
     }
 
     /**
@@ -527,18 +515,17 @@ public class CodeTester {
      * @param invocationMsg failure message prefix in case method call
      * throws exception (student code fails)
      * @param acceptNull whether null is an ok return value
+     * @param returnType the expected return type of the method (or
+     * null if don't care)
      * @param klass the class having the method
      * @param String name method name
      * @param parameterTypes the arguments
      * @return the result object
      */
-    public Object callStaticMsgAcceptNull(String msg,
-                                          String invocationMsg,
-                                          boolean acceptNull,
-                                          Class<?> klass,
-                                          String methodName,
-                                          Object... params) throws
-                  BaseTester.FailedTestException {
+    public Object callStaticMsgAcceptNull(
+        String msg, String invocationMsg, boolean acceptNull,
+        Class<?> returnType, Class<?> klass, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
         Class<?>[] parameterTypes
             = Stream.of(params).map(p -> p.getClass())
                                .toArray(Class<?>[]::new);
@@ -552,6 +539,8 @@ public class CodeTester {
                 );
             }
 
+            checkReturnType(msg, m, returnType);
+
             Object r = invokeMethod(invocationMsg, m, null, params);
 
             if (r == null && !acceptNull) {
@@ -587,24 +576,22 @@ public class CodeTester {
     /**
      * Convenience for callMsgAcceptNull with acceptNull as true
      */
-    public Object callStatic(Class<?> klass,
-                             String methodName,
-                             Object... params) throws
-                  BaseTester.FailedTestException {
-        return callStaticAcceptNull(true,
-                                    klass,
-                                    methodName,
-                                    params);
+    public Object callStatic(
+        Class<?> returnType, Class<?> klass, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
+        return callStaticAcceptNull(
+            true,
+            returnType, klass, methodName, params
+        );
     }
 
     /**
      * Convenience for callStaticMsgAcceptNull with default messages.
      */
-    public Object callStaticAcceptNull(boolean acceptNull,
-                                       Class<?> klass,
-                                       String methodName,
-                                       Object... params) throws
-                  BaseTester.FailedTestException {
+    public Object callStaticAcceptNull(
+        boolean acceptNull,
+        Class<?> returnType, Class<?> klass, String methodName, Object... params
+    ) throws BaseTester.FailedTestException {
         String base;
 
         if (getLines().size() > 0) {
@@ -619,12 +606,10 @@ public class CodeTester {
         String msg = base + "could not be made.";
         String invokeMsg = base + "went wrong.";
 
-        return callStaticMsgAcceptNull(msg,
-                                       invokeMsg,
-                                       acceptNull,
-                                       klass,
-                                       methodName,
-                                       params);
+        return callStaticMsgAcceptNull(
+            msg, invokeMsg, acceptNull,
+            returnType, klass, methodName, params
+        );
     }
 
     /**
@@ -1167,6 +1152,29 @@ public class CodeTester {
             return logIdealizer.apply(output);
     }
 
+    /**
+     * Check the method returns something compatible with returnType
+     *
+     * @param invocationMsg the message to expand in case of failure
+     * @param method the method to check
+     * @param returnType the expected return type
+     * @throws FailedTestException if no match
+     */
+    private void checkReturnType(
+        String invocationMsg, Method method, Class<?> returnType
+    ) throws BaseTester.FailedTestException {
+        if (returnType != null) {
+            if (!isAssignable(method.getReturnType(), returnType)) {
+                throw new BaseTester.FailedTestException(
+                    expandMsg(invocationMsg) + "\n\n" +
+                    "Expected the method to have a return type compatible with " +
+                    getClassString(returnType) +
+                    "."
+                );
+            }
+        }
+    }
+
     /**
      * A class to represent a null argument which can be any object type
      */
diff --git a/src/test/java/uk/ac/rhul/cs/javatester/UnitTests.java b/src/test/java/uk/ac/rhul/cs/javatester/UnitTests.java
index 66db034..41b2065 100644
--- a/src/test/java/uk/ac/rhul/cs/javatester/UnitTests.java
+++ b/src/test/java/uk/ac/rhul/cs/javatester/UnitTests.java
@@ -225,8 +225,9 @@ public class UnitTests {
                     throws BaseTester.FailedTestException {
                 CodeTester ct = new CodeTester();
                 Object node = ct.construct(STORAGE, "value");
-                ct.call(node, "setValue", "new val");
-                return ct.call(node, "getValue").equals("new val");
+                ct.call(null, node, "setValue", "new val");
+                return ct.call(String.class, node, "getValue")
+                        .equals("new val");
             }
         };
         assertTrue(tester.runTests());
@@ -240,13 +241,29 @@ public class UnitTests {
                     throws BaseTester.FailedTestException {
                 CodeTester ct = new CodeTester();
                 Object node = ct.construct(STORAGE, "value");
-                ct.call(node, "doesNotExist", "new val");
+                ct.call(null, node, "doesNotExist", "new val");
                 return true;
             }
         };
         assertFalse(tester.runTests());
     }
 
+    @Test
+    public void testCallMethodWrongReturn() {
+        BaseTester tester = new BaseTester(STORAGE) {
+            @SuppressWarnings("unused")
+            public boolean testCallMethod()
+                    throws BaseTester.FailedTestException {
+                CodeTester ct = new CodeTester();
+                Object node = ct.construct(STORAGE, "value");
+                ct.call(int.class, node, "getValue");
+                return true;
+            }
+        };
+        assertFalse(tester.runTests());
+    }
+
+
     @Test
     public void testOptInOutHasNot() {
         BaseTester tester = new BaseTester(SIMPLE_IN_OUT) {
@@ -295,10 +312,10 @@ public class UnitTests {
                 CodeTester ct = new CodeTester();
                 Object priv = ct.construct(PRIVATE_CLASS);
                 ct.construct(PRIVATE_CLASS, "isPrivate");
-                ct.call(priv, "privateMethod");
-                ct.call(priv, "protectedMethod");
-                ct.call(priv, "packageMethod");
-                ct.call(priv, "publicMethod");
+                ct.call(null, priv, "privateMethod");
+                ct.call(null, priv, "protectedMethod");
+                ct.call(null, priv, "packageMethod");
+                ct.call(null, priv, "publicMethod");
                 return true;
             }
         };
@@ -312,7 +329,7 @@ public class UnitTests {
             public boolean testNullPtr() throws FailedTestException {
                 CodeTester ct = new CodeTester();
                 Object np = ct.construct(NULL_PTR_CLASS);
-                ct.call(np, "nullPtr");
+                ct.call(null, np, "nullPtr");
                 return true;
             }
         };
@@ -340,7 +357,7 @@ public class UnitTests {
                     throws BaseTester.FailedTestException {
                 CodeTester ct = new CodeTester();
                 Object o = ct.construct(NULL_PTR_CLASS);
-                ct.callAcceptNull(false, o, "getNull");
+                ct.callAcceptNull(false, String.class, o, "getNull");
                 return true;
             }
         };
@@ -355,7 +372,7 @@ public class UnitTests {
                     throws BaseTester.FailedTestException {
                 CodeTester ct = new CodeTester();
                 Class<?> klass = ct.loadClass(NULL_PTR_CLASS);
-                ct.callStaticAcceptNull(false, klass, "getNull");
+                ct.callStaticAcceptNull(false, String.class, klass, "getNull");
                 return true;
             }
         };
-- 
GitLab