Correctly cancel network calls in catalyst instance destroy

Reviewed By: bestander

Differential Revision: D3358458

fbshipit-source-id: 69ee83ba33b02d21310030a457c6378f67e168f9
This commit is contained in:
Andrei Coman 2016-05-30 05:43:45 -07:00 committed by Facebook Github Bot 2
parent 5571794035
commit 7914d3334d
3 changed files with 145 additions and 21 deletions

View File

@ -32,8 +32,4 @@ public class OkHttpCallUtil {
} }
} }
} }
public static void cancelAll(OkHttpClient client) {
client.dispatcher().cancelAll();
}
} }

View File

@ -9,6 +9,17 @@
package com.facebook.react.modules.network; package com.facebook.react.modules.network;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.net.SocketTimeoutException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ExecutorToken; import com.facebook.react.bridge.ExecutorToken;
import com.facebook.react.bridge.GuardedAsyncTask; import com.facebook.react.bridge.GuardedAsyncTask;
@ -22,15 +33,6 @@ import com.facebook.react.bridge.WritableMap;
import com.facebook.react.common.network.OkHttpCallUtil; import com.facebook.react.common.network.OkHttpCallUtil;
import com.facebook.react.modules.core.DeviceEventManagerModule; import com.facebook.react.modules.core.DeviceEventManagerModule;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.net.SocketTimeoutException;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import okhttp3.Call; import okhttp3.Call;
import okhttp3.Callback; import okhttp3.Callback;
import okhttp3.Headers; import okhttp3.Headers;
@ -61,6 +63,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
private final ForwardingCookieHandler mCookieHandler; private final ForwardingCookieHandler mCookieHandler;
private final @Nullable String mDefaultUserAgent; private final @Nullable String mDefaultUserAgent;
private final CookieJarContainer mCookieJarContainer; private final CookieJarContainer mCookieJarContainer;
private final Set<Integer> mRequestIds;
private boolean mShuttingDown; private boolean mShuttingDown;
/* package */ NetworkingModule( /* package */ NetworkingModule(
@ -69,7 +72,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
OkHttpClient client, OkHttpClient client,
@Nullable List<NetworkInterceptorCreator> networkInterceptorCreators) { @Nullable List<NetworkInterceptorCreator> networkInterceptorCreators) {
super(reactContext); super(reactContext);
if (networkInterceptorCreators != null) { if (networkInterceptorCreators != null) {
OkHttpClient.Builder clientBuilder = client.newBuilder(); OkHttpClient.Builder clientBuilder = client.newBuilder();
for (NetworkInterceptorCreator networkInterceptorCreator : networkInterceptorCreators) { for (NetworkInterceptorCreator networkInterceptorCreator : networkInterceptorCreators) {
@ -83,6 +86,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
mCookieJarContainer = (CookieJarContainer) mClient.cookieJar(); mCookieJarContainer = (CookieJarContainer) mClient.cookieJar();
mShuttingDown = false; mShuttingDown = false;
mDefaultUserAgent = defaultUserAgent; mDefaultUserAgent = defaultUserAgent;
mRequestIds = new HashSet<>();
} }
/** /**
@ -138,7 +142,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
@Override @Override
public void onCatalystInstanceDestroy() { public void onCatalystInstanceDestroy() {
mShuttingDown = true; mShuttingDown = true;
OkHttpCallUtil.cancelAll(mClient); cancelAllRequests();
mCookieHandler.destroy(); mCookieHandler.destroy();
mCookieJarContainer.removeCookieJar(); mCookieJarContainer.removeCookieJar();
@ -241,6 +245,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
requestBuilder.method(method, RequestBodyUtil.getEmptyBody(method)); requestBuilder.method(method, RequestBodyUtil.getEmptyBody(method));
} }
addRequest(requestId);
client.newCall(requestBuilder.build()).enqueue( client.newCall(requestBuilder.build()).enqueue(
new Callback() { new Callback() {
@Override @Override
@ -248,6 +253,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
if (mShuttingDown) { if (mShuttingDown) {
return; return;
} }
removeRequest(requestId);
onRequestError(executorToken, requestId, e.getMessage(), e); onRequestError(executorToken, requestId, e.getMessage(), e);
} }
@ -256,7 +262,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
if (mShuttingDown) { if (mShuttingDown) {
return; return;
} }
removeRequest(requestId);
// Before we touch the body send headers to JS // Before we touch the body send headers to JS
onResponseReceived(executorToken, requestId, response); onResponseReceived(executorToken, requestId, response);
@ -335,6 +341,21 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
getEventEmitter(ExecutorToken).emit("didReceiveNetworkResponse", args); getEventEmitter(ExecutorToken).emit("didReceiveNetworkResponse", args);
} }
private synchronized void addRequest(int requestId) {
mRequestIds.add(requestId);
}
private synchronized void removeRequest(int requestId) {
mRequestIds.remove(requestId);
}
private synchronized void cancelAllRequests() {
for (Integer requestId : mRequestIds) {
cancelRequest(requestId);
}
mRequestIds.clear();
}
private static WritableMap translateHeaders(Headers headers) { private static WritableMap translateHeaders(Headers headers) {
WritableMap responseHeaders = Arguments.createMap(); WritableMap responseHeaders = Arguments.createMap();
for (int i = 0; i < headers.size(); i++) { for (int i = 0; i < headers.size(); i++) {
@ -353,6 +374,11 @@ public final class NetworkingModule extends ReactContextBaseJavaModule {
@ReactMethod @ReactMethod
public void abortRequest(ExecutorToken executorToken, final int requestId) { public void abortRequest(ExecutorToken executorToken, final int requestId) {
cancelRequest(requestId);
removeRequest(requestId);
}
private void cancelRequest(final int requestId) {
// We have to use AsyncTask since this might trigger a NetworkOnMainThreadException, this is an // We have to use AsyncTask since this might trigger a NetworkOnMainThreadException, this is an
// open issue on OkHttp: https://github.com/square/okhttp/issues/869 // open issue on OkHttp: https://github.com/square/okhttp/issues/869
new GuardedAsyncTask<Void, Void>(getReactApplicationContext()) { new GuardedAsyncTask<Void, Void>(getReactApplicationContext()) {

View File

@ -15,12 +15,13 @@ import java.util.List;
import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ExecutorToken; import com.facebook.react.bridge.ExecutorToken;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContext;
import com.facebook.react.bridge.JavaOnlyArray; import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap; import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContext;
import com.facebook.react.bridge.WritableArray; import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap; import com.facebook.react.bridge.WritableMap;
import com.facebook.react.common.network.OkHttpCallUtil;
import com.facebook.react.modules.core.DeviceEventManagerModule.RCTDeviceEventEmitter; import com.facebook.react.modules.core.DeviceEventManagerModule.RCTDeviceEventEmitter;
import okhttp3.Call; import okhttp3.Call;
@ -31,7 +32,6 @@ import okhttp3.OkHttpClient;
import okhttp3.Request; import okhttp3.Request;
import okhttp3.RequestBody; import okhttp3.RequestBody;
import okio.Buffer; import okio.Buffer;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -40,8 +40,8 @@ import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito; import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.rule.PowerMockRule; import org.powermock.modules.junit4.rule.PowerMockRule;
import org.robolectric.RobolectricTestRunner; import org.robolectric.RobolectricTestRunner;
@ -63,7 +63,8 @@ import static org.mockito.Mockito.when;
MultipartBody.class, MultipartBody.class,
MultipartBody.Builder.class, MultipartBody.Builder.class,
NetworkingModule.class, NetworkingModule.class,
OkHttpClient.class}) OkHttpClient.class,
OkHttpCallUtil.class})
@RunWith(RobolectricTestRunner.class) @RunWith(RobolectricTestRunner.class)
@PowerMockIgnore({"org.mockito.*", "org.robolectric.*", "android.*"}) @PowerMockIgnore({"org.mockito.*", "org.robolectric.*", "android.*"})
public class NetworkingModuleTest { public class NetworkingModuleTest {
@ -476,4 +477,105 @@ public class NetworkingModuleTest {
assertThat(bodyRequestBody.get(1).contentType()).isEqualTo(MediaType.parse("image/jpg")); assertThat(bodyRequestBody.get(1).contentType()).isEqualTo(MediaType.parse("image/jpg"));
assertThat(bodyRequestBody.get(1).contentLength()).isEqualTo("imageUri".getBytes().length); assertThat(bodyRequestBody.get(1).contentLength()).isEqualTo("imageUri".getBytes().length);
} }
@Test
public void testCancelAllCallsOnCatalystInstanceDestroy() throws Exception {
PowerMockito.mockStatic(OkHttpCallUtil.class);
OkHttpClient httpClient = mock(OkHttpClient.class);
final int requests = 3;
final Call[] calls = new Call[requests];
for (int idx = 0; idx < requests; idx++) {
calls[idx] = mock(Call.class);
}
when(httpClient.cookieJar()).thenReturn(mock(CookieJarContainer.class));
when(httpClient.newCall(any(Request.class))).thenAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
Request request = (Request) invocation.getArguments()[0];
return calls[(Integer) request.tag() - 1];
}
});
NetworkingModule networkingModule = new NetworkingModule(null, "", httpClient);
networkingModule.initialize();
for (int idx = 0; idx < requests; idx++) {
networkingModule.sendRequest(
mock(ExecutorToken.class),
"GET",
"http://somedomain/foo",
idx + 1,
JavaOnlyArray.of(),
null,
true,
0);
}
verify(httpClient, times(3)).newCall(any(Request.class));
networkingModule.onCatalystInstanceDestroy();
PowerMockito.verifyStatic(times(3));
ArgumentCaptor<OkHttpClient> clientArguments = ArgumentCaptor.forClass(OkHttpClient.class);
ArgumentCaptor<Integer> requestIdArguments = ArgumentCaptor.forClass(Integer.class);
OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture());
assertThat(requestIdArguments.getAllValues().size()).isEqualTo(requests);
for (int idx = 0; idx < requests; idx++) {
assertThat(requestIdArguments.getAllValues().contains(idx + 1)).isTrue();
}
}
@Test
public void testCancelSomeCallsOnCatalystInstanceDestroy() throws Exception {
PowerMockito.mockStatic(OkHttpCallUtil.class);
OkHttpClient httpClient = mock(OkHttpClient.class);
final int requests = 3;
final Call[] calls = new Call[requests];
for (int idx = 0; idx < requests; idx++) {
calls[idx] = mock(Call.class);
}
when(httpClient.cookieJar()).thenReturn(mock(CookieJarContainer.class));
when(httpClient.newCall(any(Request.class))).thenAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
Request request = (Request) invocation.getArguments()[0];
return calls[(Integer) request.tag() - 1];
}
});
NetworkingModule networkingModule = new NetworkingModule(null, "", httpClient);
for (int idx = 0; idx < requests; idx++) {
networkingModule.sendRequest(
mock(ExecutorToken.class),
"GET",
"http://somedomain/foo",
idx + 1,
JavaOnlyArray.of(),
null,
true,
0);
}
verify(httpClient, times(3)).newCall(any(Request.class));
networkingModule.abortRequest(mock(ExecutorToken.class), requests);
PowerMockito.verifyStatic(times(1));
ArgumentCaptor<OkHttpClient> clientArguments = ArgumentCaptor.forClass(OkHttpClient.class);
ArgumentCaptor<Integer> requestIdArguments = ArgumentCaptor.forClass(Integer.class);
OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture());
assertThat(requestIdArguments.getAllValues().size()).isEqualTo(1);
assertThat(requestIdArguments.getAllValues().get(0)).isEqualTo(requests);
// verifyStatic actually does not clear all calls so far, so we have to check for all of them.
// If `cancelTag` would've been called again for the aborted call, we would have had
// `requests + 1` calls.
networkingModule.onCatalystInstanceDestroy();
PowerMockito.verifyStatic(times(requests));
clientArguments = ArgumentCaptor.forClass(OkHttpClient.class);
requestIdArguments = ArgumentCaptor.forClass(Integer.class);
OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture());
assertThat(requestIdArguments.getAllValues().size()).isEqualTo(requests);
for (int idx = 0; idx < requests; idx++) {
assertThat(requestIdArguments.getAllValues().contains(idx + 1)).isTrue();
}
}
} }