diff --git a/Libraries/WebSocket/RCTSRWebSocket.h b/Libraries/WebSocket/RCTSRWebSocket.h index 127530434..1b17cffaf 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.h +++ b/Libraries/WebSocket/RCTSRWebSocket.h @@ -60,19 +60,17 @@ extern NSString *const RCTSRHTTPResponseErrorKey; @property (nonatomic, readonly, copy) NSString *protocol; // Protocols should be an array of strings that turn into Sec-WebSocket-Protocol. -// options can contain a custom "origin" NSString -- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols options:(NSDictionary *)options NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols NS_DESIGNATED_INITIALIZER; - (instancetype)initWithURLRequest:(NSURLRequest *)request; // Some helper constructors. -- (instancetype)initWithURL:(NSURL *)url protocols:(NSArray *)protocols options:(NSDictionary *)options; - (instancetype)initWithURL:(NSURL *)url protocols:(NSArray *)protocols; - (instancetype)initWithURL:(NSURL *)url; // Delegate queue will be dispatch_main_queue by default. // You cannot set both OperationQueue and dispatch_queue. -- (void)setDelegateOperationQueue:(NSOperationQueue*) queue; -- (void)setDelegateDispatchQueue:(dispatch_queue_t) queue; +- (void)setDelegateOperationQueue:(NSOperationQueue *)queue; +- (void)setDelegateDispatchQueue:(dispatch_queue_t)queue; // By default, it will schedule itself on +[NSRunLoop RCTSR_networkRunLoop] using defaultModes. - (void)scheduleInRunLoop:(NSRunLoop *)aRunLoop forMode:(NSString *)mode; diff --git a/Libraries/WebSocket/RCTSRWebSocket.m b/Libraries/WebSocket/RCTSRWebSocket.m index bf67c3b14..fe02d0697 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.m +++ b/Libraries/WebSocket/RCTSRWebSocket.m @@ -234,7 +234,6 @@ typedef void (^data_callback)(RCTSRWebSocket *webSocket, NSData *data); __strong RCTSRWebSocket *_selfRetain; NSArray *_requestedProtocols; - NSDictionary *_requestedOptions; RCTSRIOConsumerPool *_consumerPool; } @@ -245,7 +244,7 @@ static __strong NSData *CRLFCRLF; CRLFCRLF = [[NSData alloc] initWithBytes:"\r\n\r\n" length:4]; } -- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols options:(NSDictionary *)options +- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols { RCTAssertParam(request); @@ -254,7 +253,6 @@ static __strong NSData *CRLFCRLF; _urlRequest = request; _requestedProtocols = [protocols copy]; - _requestedOptions = [options copy]; [self _RCTSR_commonInit]; } @@ -265,20 +263,15 @@ RCT_NOT_IMPLEMENTED(- (instancetype)init) - (instancetype)initWithURLRequest:(NSURLRequest *)request; { - return [self initWithURLRequest:request protocols:nil options: nil]; + return [self initWithURLRequest:request protocols:nil]; } - (instancetype)initWithURL:(NSURL *)URL; { - return [self initWithURL:URL protocols:nil options:nil]; + return [self initWithURL:URL protocols:nil]; } - (instancetype)initWithURL:(NSURL *)URL protocols:(NSArray *)protocols; -{ - return [self initWithURL:URL protocols:protocols options:nil]; -} - -- (instancetype)initWithURL:(NSURL *)URL protocols:(NSArray *)protocols options:(NSDictionary *)options { NSMutableURLRequest *request; if (URL) { @@ -297,7 +290,7 @@ RCT_NOT_IMPLEMENTED(- (instancetype)init) NSArray *cookies = [[NSHTTPCookieStorage sharedHTTPCookieStorage] cookiesForURL:components.URL]; [request setAllHTTPHeaderFields:[NSHTTPCookie requestHeaderFieldsWithCookies:cookies]]; } - return [self initWithURLRequest:request protocols:protocols options:options]; + return [self initWithURLRequest:request protocols:protocols]; } - (void)_RCTSR_commonInit; @@ -488,12 +481,12 @@ RCT_NOT_IMPLEMENTED(- (instancetype)init) CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Key"), (__bridge CFStringRef)_secKey); CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Version"), (__bridge CFStringRef)[NSString stringWithFormat:@"%ld", (long)_webSocketVersion]); + CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Origin"), (__bridge CFStringRef)_url.RCTSR_origin); + if (_requestedProtocols) { CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Protocol"), (__bridge CFStringRef)[_requestedProtocols componentsJoinedByString:@", "]); } - CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Origin"), (__bridge CFStringRef)(_requestedOptions[@"origin"] ?: _url.RCTSR_origin)); - [_urlRequest.allHTTPHeaderFields enumerateKeysAndObjectsUsingBlock:^(id key, id obj, BOOL *stop) { CFHTTPMessageSetHeaderFieldValue(request, (__bridge CFStringRef)key, (__bridge CFStringRef)obj); }]; diff --git a/Libraries/WebSocket/RCTWebSocketModule.m b/Libraries/WebSocket/RCTWebSocketModule.m index 6e057f78b..aa5d0e412 100644 --- a/Libraries/WebSocket/RCTWebSocketModule.m +++ b/Libraries/WebSocket/RCTWebSocketModule.m @@ -11,6 +11,7 @@ #import "RCTBridge.h" #import "RCTEventDispatcher.h" +#import "RCTConvert.h" #import "RCTUtils.h" @implementation RCTSRWebSocket (React) @@ -44,9 +45,14 @@ RCT_EXPORT_MODULE() } } -RCT_EXPORT_METHOD(connect:(NSURL *)URL protocols:(NSArray *)protocols options:(NSDictionary *)options socketID:(nonnull NSNumber *)socketID) +RCT_EXPORT_METHOD(connect:(NSURL *)URL protocols:(NSArray *)protocols headers:(NSDictionary *)headers socketID:(nonnull NSNumber *)socketID) { - RCTSRWebSocket *webSocket = [[RCTSRWebSocket alloc] initWithURL:URL protocols:protocols options:options]; + NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:URL]; + [headers enumerateKeysAndObjectsUsingBlock:^(NSString *key, id value, BOOL *stop) { + [request addValue:[RCTConvert NSString:value] forHTTPHeaderField:key]; + }]; + + RCTSRWebSocket *webSocket = [[RCTSRWebSocket alloc] initWithURLRequest:request protocols:protocols]; webSocket.delegate = self; webSocket.reactTag = socketID; if (!_sockets) { diff --git a/Libraries/WebSocket/WebSocket.js b/Libraries/WebSocket/WebSocket.js index ec6f722ef..905a9f655 100644 --- a/Libraries/WebSocket/WebSocket.js +++ b/Libraries/WebSocket/WebSocket.js @@ -33,10 +33,10 @@ class WebSocket extends WebSocketBase { _socketId: number; _subs: any; - connectToSocketImpl(url: string, protocols: ?Array, options: ?{origin?: string}): void { + connectToSocketImpl(url: string, protocols: ?Array, headers: ?Object): void { this._socketId = WebSocketId++; - RCTWebSocketModule.connect(url, protocols, options, this._socketId); + RCTWebSocketModule.connect(url, protocols, headers, this._socketId); this._registerEvents(this._socketId); } diff --git a/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java b/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java index af8d79597..1d9d1dda1 100644 --- a/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java +++ b/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java @@ -34,6 +34,8 @@ import com.squareup.okhttp.ws.WebSocket; import com.squareup.okhttp.ws.WebSocketCall; import com.squareup.okhttp.ws.WebSocketListener; +import java.net.URISyntaxException; +import java.net.URI; import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -63,7 +65,7 @@ public class WebSocketModule extends ReactContextBaseJavaModule { } @ReactMethod - public void connect(final String url, @Nullable final ReadableArray protocols, @Nullable final ReadableMap options, final int id) { + public void connect(final String url, @Nullable final ReadableArray protocols, @Nullable final ReadableMap headers, final int id) { // ignoring protocols, since OKHttp overrides them. OkHttpClient client = new OkHttpClient(); @@ -76,14 +78,25 @@ public class WebSocketModule extends ReactContextBaseJavaModule { .tag(id) .url(url); - if (options != null && options.hasKey("origin")) { - if (ReadableType.String.equals(options.getType("origin"))) { - builder.addHeader("Origin", options.getString("origin")); - } else { - FLog.w( - ReactConstants.TAG, - "Ignoring: requested origin, value not a string"); + if (headers != null) { + ReadableMapKeySetIterator iterator = headers.keySetIterator(); + + if (!headers.hasKey("origin")) { + builder.addHeader("origin", setDefaultOrigin(url)); } + + while (iterator.hasNextKey()) { + String key = iterator.nextKey(); + if (ReadableType.String.equals(headers.getType(key))) { + builder.addHeader(key, headers.getString(key)); + } else { + FLog.w( + ReactConstants.TAG, + "Ignoring: requested " + key + ", value not a string"); + } + } + } else { + builder.addHeader("origin", setDefaultOrigin(url)); } WebSocketCall.create(client, builder.build()).enqueue(new WebSocketListener() { @@ -188,4 +201,37 @@ public class WebSocketModule extends ReactContextBaseJavaModule { params.putString("message", message); sendEvent("websocketFailed", params); } + + /** + * Set a default origin + * + * @param Websocket connection endpoint + * @return A string of the endpoint converted to HTTP protocol + */ + + private static String setDefaultOrigin(String uri) { + try { + String defaultOrigin; + String scheme = ""; + + URI requestURI = new URI(uri); + if (requestURI.getScheme().equals("wss")) { + scheme += "https"; + } else if (requestURI.getScheme().equals("ws")) { + scheme += "http"; + } + + if (requestURI.getPort() != -1) { + defaultOrigin = String.format("%s://%s:%s", scheme, requestURI.getHost(), requestURI.getPort()); + } else { + defaultOrigin = String.format("%s://%s/", scheme, requestURI.getHost()); + } + + return defaultOrigin; + + } catch(URISyntaxException e) { + throw new IllegalArgumentException("Unable to set " + uri + " as default origin header."); + } + } + }