// Copyright 1997-2001 Omni Development, Inc.  All rights reserved.
//
// This software may only be used and reproduced according to the
// terms in the file OmniSourceLicense.html, which should be
// distributed with this project and can also be found at
// http://www.omnigroup.com/DeveloperResources/OmniSourceLicense.html.

#import "ONInternetSocket.h"

#import <Foundation/Foundation.h>
#import <OmniBase/OmniBase.h>
#import <OmniBase/system.h>

#import "ONInternetSocket-Private.h"
#import "ONServiceEntry.h"
#import "ONHostAddress.h"
#import "ONHost.h"
#import "ONInterface.h"

RCS_ID("$Header: /NetworkDisk/Source/CVS/OmniGroup/Frameworks/OmniNetworking/ONInternetSocket.m,v 1.32.4.1 2001/06/23 00:01:26 kc Exp $")

// Private class-level helpers used by ONInternetSocket and its categories.
// +createSocketFD opens a new descriptor using the subclass-supplied
// +protocolFamily, +socketType and +ipProtocol values.
@interface ONInternetSocket (Private)
+ (int)createSocketFD;
@end

@implementation ONInternetSocket

// Protocol family passed to socket(); every ONInternetSocket is IPv4.
+ (int)protocolFamily;
{
    int family = PF_INET;

    return family;
}

// Abstract: concrete subclasses must report the IP protocol they speak.
+ (int)ipProtocol;
{
    OBRequestConcreteImplementation(self, _cmd);
    return NSNotFound; // Not executed
}

// Abstract: concrete subclasses must report their socket type.
+ (int)socketType;
{
    OBRequestConcreteImplementation(self, _cmd);
    return NSNotFound; // Not executed
}

// Returns a new, autoreleased, unconnected socket of the receiving class,
// backed by a freshly created descriptor.
+ (ONInternetSocket *)socket;
{
    int freshSocketFD = [self createSocketFD];
    ONInternetSocket *newSocket = [[self alloc] _initWithSocketFD:freshSocketFD connected:NO];

    return [newSocket autorelease];
}

// Init and dealloc

// Closes the descriptor if it is still open. Frees the lazily cached local and
// remote address buffers, and releases the lazily retained remote host.
- (void)dealloc;
{
    if (socketFD != -1) {
        BOOL socketCloseSucceeded;

        socketCloseSucceeded = OBSocketClose(socketFD) == 0;
        OBASSERT(socketCloseSucceeded);
        socketFD = -1;
    }

    if (localAddress)
        NSZoneFree(NULL, localAddress);
    if (remoteAddress)
        NSZoneFree(NULL, remoteAddress);

    // -remoteAddressHost retains remoteHost when it first caches it; balance
    // that retain here (messaging nil is harmless if it was never fetched).
    [remoteHost release];

    [super dealloc];
}

//

// Binds the socket to whatever local port the system picks (port 0 means "any").
- (void)setLocalPortNumber;
{
    unsigned short anyAvailablePort = 0;

    [self setLocalPortNumber:anyAvailablePort];
}

// Binds the socket to `port` on all local IPv4 addresses (INADDR_ANY).
// Raises ONInternetSocketBindFailedExceptionName if bind() fails.
- (void)setLocalPortNumber:(unsigned short)port;
{
    struct sockaddr_in socketAddress;

    // Clear the whole structure first so the padding fields (sin_zero, and
    // sin_len on BSD-derived systems) are not uninitialized stack garbage.
    memset(&socketAddress, 0, sizeof(socketAddress));
    socketAddress.sin_family      = AF_INET;
    socketAddress.sin_addr.s_addr = htonl(INADDR_ANY);
    socketAddress.sin_port        = htons(port);

    if (bind(socketFD, (struct sockaddr *)&socketAddress, sizeof(socketAddress)) == -1) {
        // Read errno exactly once so the code and the message always agree.
        int bindError = OMNI_ERRNO();

        [NSException raise:ONInternetSocketBindFailedExceptionName posixErrorNumber:bindError format:@"Unable to bind a socket: %s", strerror(bindError)];
    }
}

// Binds to `port`. When `reuse` is YES, it first enables SO_REUSEADDR and then
// SO_REUSEPORT on the socket. Raises
// ONInternetSocketReuseSelectionFailedExceptionName if either option cannot be set.
- (void)setLocalPortNumber:(unsigned short int)port allowingAddressReuse:(BOOL)reuse;
{
    if (reuse) {
        static const int reuseOptions[] = { SO_REUSEADDR, SO_REUSEPORT };
        unsigned int optionIndex;

        for (optionIndex = 0; optionIndex < sizeof(reuseOptions) / sizeof(reuseOptions[0]); optionIndex++) {
            // setsockopt() takes an int flag rather than a BOOL.
            int enable = 1;

            if (setsockopt(socketFD, SOL_SOCKET, reuseOptions[optionIndex], &enable, sizeof(enable)) == -1)
                [NSException raise:ONInternetSocketReuseSelectionFailedExceptionName posixErrorNumber:OMNI_ERRNO() format:@"Failed to set address reuse on socket: %s", strerror(OMNI_ERRNO())];
        }
    }

    [self setLocalPortNumber:port];
}


// Returns the socket's local address, as reported by getsockname().
// The result is cached in the localAddress ivar (freed in -dealloc) after the
// first successful call. Raises ONInternetSocketGetNameFailedExceptionName on failure.
- (const struct sockaddr_in *)localAddress;
{
    int nameLength;
    struct sockaddr_in *boundAddress;

    if (localAddress == NULL) {
        nameLength = sizeof(struct sockaddr_in);
        boundAddress = (struct sockaddr_in *)NSZoneMalloc(NULL, nameLength);
        if (getsockname(socketFD, (struct sockaddr *)boundAddress, &nameLength) == -1) {
            NSZoneFree(NULL, boundAddress);
            [NSException raise:ONInternetSocketGetNameFailedExceptionName posixErrorNumber:OMNI_ERRNO() format:@"Unable to get socket name: %s", strerror(OMNI_ERRNO())];
        }
        localAddress = boundAddress;
    }

    return localAddress;
}

// Local IPv4 address in host byte order.
- (unsigned long int)localAddressHostNumber;
{
    const struct sockaddr_in *ourAddress = [self localAddress];

    return ntohl(ourAddress->sin_addr.s_addr);
}

// Local port number in host byte order.
- (unsigned short int)localAddressPort;
{
    const struct sockaddr_in *ourAddress = [self localAddress];

    return ntohs(ourAddress->sin_port);
}

//

// Returns the ONHost for the connected peer's address. It is looked up once via
// +[ONHost hostForAddress:] and retained in the remoteHost ivar for later calls.
- (ONHost *)remoteAddressHost;
{
    if (remoteHost == nil) {
        const struct sockaddr_in *peer = [self remoteAddress];
        ONHostAddress *peerHostAddress = [ONHostAddress hostAddressWithInternetAddress:&(peer->sin_addr)];

        remoteHost = [[ONHost hostForAddress:peerHostAddress] retain];
    }

    return remoteHost;
}

// Returns the connected peer's address, as reported by getpeername().
// The result is cached in the remoteAddress ivar (freed in -dealloc) after the
// first successful call. Raises ONInternetSocketGetNameFailedExceptionName on failure.
- (const struct sockaddr_in *)remoteAddress;
{
    int addressLength;
    struct sockaddr *address;

    if (remoteAddress)
        return remoteAddress;

    addressLength = sizeof(struct sockaddr_in);
    address = (struct sockaddr *)NSZoneMalloc(NULL, addressLength);
    if (getpeername(socketFD, address, &addressLength) == -1) {
        // Save errno before freeing the buffer so the report reflects getpeername().
        int peerNameError = OMNI_ERRNO();

        NSZoneFree(NULL, address);
        // This is a getpeername() failure, not getsockname(); say so in the message.
        [NSException raise:ONInternetSocketGetNameFailedExceptionName posixErrorNumber:peerNameError format:@"Unable to get peer name: %s", strerror(peerNameError)];
    }
    remoteAddress = (struct sockaddr_in *)address;
    return remoteAddress;
}

// Peer IPv4 address in host byte order.
- (unsigned long)remoteAddressHostNumber;
{
    const struct sockaddr_in *peer = [self remoteAddress];

    return ntohl(peer->sin_addr.s_addr);
}

// Peer port number in host byte order.
- (unsigned short)remoteAddressPort;
{
    const struct sockaddr_in *peer = [self remoteAddress];

    return ntohs(peer->sin_port);
}

//

#ifdef HAVE_ONInterface
// Returns the network interface whose address equals this socket's local
// address. +[ONInterface interfaces] is scanned from its last element toward
// its first, and the first match wins. Raises NSInternalInconsistencyException
// if no interface matches.
- (ONInterface *) localInterface;
{
    const struct sockaddr_in *ourAddress;
    NSEnumerator *interfaceEnumerator;
    ONInterface *candidate;

    // Also caches the local address in the localAddress ivar.
    ourAddress = [self localAddress];

    interfaceEnumerator = [[ONInterface interfaces] reverseObjectEnumerator];
    while ((candidate = [interfaceEnumerator nextObject]) != nil) {
        const struct in_addr *candidateAddress = [[candidate interfaceAddress] internetAddress];

        // In the future, we might want to handle the case in which multiple network
        // interfaces have the same IP address.  In that case, we'd either need to
        // return an array of the possible interfaces, or look at a destination address
        // and the routing tables in order to determine which interface will get used.
        // There still might be multiple interfaces that might be sharing the load,
        // in which case it seems like we wouldn't be able to determine with any certainty
        // which one would get used.  The best solution would probably be to return an
        // array of possibilities and let the caller work with that set (for example, in
        // the case of the -maximumTransmissionUnit, they could just use the smallest unit).
        if (candidateAddress->s_addr == ourAddress->sin_addr.s_addr)
            return candidate;
    }

    [NSException raise: NSInternalInconsistencyException
                format: @"No interface found matching local address for socket %@.", self];
    return nil;
}
#endif

//

// Connects to `host` on the port named by `service`; see -connectToHost:port:.
- (void)connectToHost:(ONHost *)host serviceEntry:(ONServiceEntry *)service;
{
    unsigned short int servicePort = [service portNumber];

    [self connectToHost:host port:servicePort];
}

// Connects to `port` on `host`, trying each of the host's addresses in order
// until one succeeds.
// - Raises ONInternetSocketConnectFailedExceptionName immediately if the host
//   has no addresses.
// - Any exception from -connectToAddress:port: other than
//   ONInternetSocketConnectTemporarilyFailedExceptionName is re-raised at once.
// - If every address fails temporarily, the first temporary exception seen is
//   re-raised.
- (void)connectToHost:(ONHost *)host port:(unsigned short int)port;
{
    NSException *firstTemporaryException = nil;
    NSArray *hostAddresses;
    NSEnumerator *hostAddressEnumerator;
    ONHostAddress *hostAddress;

    OBASSERT(!flags.connected);
    hostAddresses = [host addresses];
    if ([hostAddresses count] == 0)
	[NSException raise:ONInternetSocketConnectFailedExceptionName
             format:NSLocalizedStringFromTableInBundle(@"Unable to connect: no IP address for host '%@'", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), [host hostname]];
    hostAddressEnumerator = [hostAddresses objectEnumerator];
    // flags.connected is set by -connectToAddress:port: on success, ending the loop.
    while (!flags.connected &&
	   (hostAddress = [hostAddressEnumerator nextObject])) {
	// Locals are only assigned outside NS_DURING (in the loop condition or in
	// the handler), which keeps them safe across the handler's non-local jump.
	NS_DURING {
	    [self connectToAddress:hostAddress port:port];
	} NS_HANDLER {
	    // Only temporary failures are retried against the next address.
	    if (![[localException name] isEqualToString:
		  ONInternetSocketConnectTemporarilyFailedExceptionName])
		[localException raise];
	    if (!firstTemporaryException)
		firstTemporaryException = localException;
	} NS_ENDHANDLER;
    }
    // Reached without a connection only after every address failed temporarily.
    if (!flags.connected)
	[firstTemporaryException raise];
}

// Connects to `port` at `hostAddress`, creating a descriptor first if the
// socket has none (socketFD == -1).
// - Raises ONInternetSocketUserAbortExceptionName if -abortSocket ran while
//   connect() was in progress.
// - On failure the descriptor is closed and socketFD reset to -1. Then one of
//   two exceptions is raised: ONInternetSocketConnectTemporarilyFailedExceptionName
//   for timeouts, refusals and unreachable networks/hosts, or
//   ONInternetSocketConnectFailedExceptionName for anything else.
- (void)connectToAddress:(ONHostAddress *)hostAddress port:(unsigned short int)port;
{
    BOOL connectSucceeded, socketCloseSucceeded;
    int connectError = 0;
    struct sockaddr_in socketAddress;

    OBASSERT(!flags.connected);
    if (socketFD == -1)
        socketFD = [[self class] createSocketFD];

    // Clear the structure so padding fields (sin_zero, sin_len on BSD) are not garbage.
    memset(&socketAddress, 0, sizeof(socketAddress));
    socketAddress.sin_family = AF_INET;
    socketAddress.sin_addr = *[hostAddress internetAddress];
    socketAddress.sin_port = htons(port);

    connectSucceeded = connect(socketFD, (struct sockaddr *)&socketAddress, sizeof(socketAddress)) == 0;

    // Capture the failure reason right away. Closing the socket, the abort
    // check, and the NSBundle/localization lookups below can all overwrite errno
    // before it would otherwise be read.
    if (!connectSucceeded)
        connectError = OMNI_ERRNO();

    // Check to see if the user aborted the connect()
    if (flags.userAbort)
        [NSException raise:ONInternetSocketUserAbortExceptionName format:NSLocalizedStringFromTableInBundle(@"Connect aborted", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error)];

    if (connectSucceeded) {
        flags.connected = YES;
        return;
    }

    OBASSERT(socketFD != -1);
    socketCloseSucceeded = OBSocketClose(socketFD) == 0;
    OBASSERT(socketCloseSucceeded);
    socketFD = -1;

    switch (connectError) {
        case ETIMEDOUT:
        case ECONNREFUSED:
        case ENETDOWN:
        case ENETUNREACH:
        case EHOSTDOWN:
        case EHOSTUNREACH:
            [NSException raise:ONInternetSocketConnectTemporarilyFailedExceptionName posixErrorNumber:connectError format:NSLocalizedStringFromTableInBundle(@"Temporarily unable to connect: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(connectError)];
        default:
            [NSException raise:ONInternetSocketConnectFailedExceptionName posixErrorNumber:connectError format:NSLocalizedStringFromTableInBundle(@"Unable to connect: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(connectError)];
    }
}


// Waits up to `timeout` seconds (negative values are treated as 0) for the
// connected socket to become readable. Returns YES if input is ready and NO on
// timeout.
// - Raises ONInternetSocketNotConnectedExceptionName when the socket is not connected.
// - Raises ONInternetSocketReadFailedExceptionName if select() fails.
- (BOOL)waitForInputWithTimeout:(NSTimeInterval)timeout;
{
    struct timeval waitLimit;
    fd_set watchedForRead;
    int readyCount;

    if (!flags.connected)
        [NSException raise:ONInternetSocketNotConnectedExceptionName format:@"Attempted read from a non-connected socket"];

    timeout = (timeout < 0.0) ? 0.0 : timeout;

    // Split the fractional seconds into whole seconds plus microseconds.
    waitLimit.tv_sec = timeout;
    waitLimit.tv_usec = 1.0e6 * (timeout - waitLimit.tv_sec);

    FD_ZERO(&watchedForRead);
    FD_SET(socketFD, &watchedForRead);
    readyCount = select(socketFD + 1, &watchedForRead, NULL, NULL, &waitLimit);

    if (readyCount == -1)
        [NSException raise:ONInternetSocketReadFailedExceptionName posixErrorNumber:OMNI_ERRNO() format:NSLocalizedStringFromTableInBundle(@"Error waiting for input: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(OMNI_ERRNO())];
    if (readyCount == 0)
        return NO;
    return FD_ISSET(socketFD, &watchedForRead) != 0;
}

// Enables or disables SO_BROADCAST on the socket. Raises
// ONInternetSocketBroadcastSelectionFailedExceptionName if setsockopt() fails.
- (void)setAllowsBroadcast:(BOOL)shouldAllowBroadcast;
{
    // SO_BROADCAST takes an int flag; normalize the BOOL to exactly 0 or 1.
    int allows = shouldAllowBroadcast ? 1 : 0;

    if (setsockopt(socketFD, SOL_SOCKET, SO_BROADCAST, (char *)&allows, sizeof(allows)) == -1)
        [NSException raise:ONInternetSocketBroadcastSelectionFailedExceptionName posixErrorNumber:OMNI_ERRNO() format:@"Failed to set broadcast to %d on socket: %s", allows, strerror(OMNI_ERRNO())];
}

// The underlying descriptor, or -1 when the socket has none.
- (int)socketFD;
{
    int descriptor = socketFD;

    return descriptor;
}

// YES once a connect() has succeeded and the socket has not been aborted.
- (BOOL)isConnected;
{
    if (flags.connected > 0)
        return YES;
    return NO;
}

// YES if -abortSocket has been called on this socket.
- (BOOL)didAbort;
{
    if (flags.userAbort > 0)
        return YES;
    return NO;
}

// Non-blocking check: YES if select() reports the descriptor writable right now.
// Always NO when there is no descriptor.
- (BOOL)isWritable;
{
    struct timeval noWait;
    fd_set candidates;

    if (socketFD == -1)
        return NO;

    // A zero timeout turns select() into an immediate poll.
    noWait.tv_sec = 0;
    noWait.tv_usec = 0;
    FD_ZERO(&candidates);
    FD_SET(socketFD, &candidates);

    if (select(socketFD + 1, NULL, &candidates, NULL, &noWait) == 1)
        return YES;
    return NO;
}

// Non-blocking check: YES if select() reports the descriptor readable right now.
// Always NO when there is no descriptor.
- (BOOL)isReadable;
{
    struct timeval noWait;
    fd_set candidates;

    if (socketFD == -1)
        return NO;

    // A zero timeout turns select() into an immediate poll.
    noWait.tv_sec = 0;
    noWait.tv_usec = 0;
    FD_ZERO(&candidates);
    FD_SET(socketFD, &candidates);

    if (select(socketFD + 1, &candidates, NULL, NULL, &noWait) == 1)
        return YES;
    return NO;
}

// ONSocket subclass

// ONSocket override.
// - Marks the socket as user-aborted.
// - If a descriptor is open: resets socketFD to -1, shuts the old descriptor
//   down in both directions (only when connected), then closes it.
// - Always clears the connected flag.
- (void)abortSocket;
{
    flags.userAbort = YES;

    if (socketFD != -1) {
        int closingFD = socketFD;
        BOOL socketCloseSucceeded;

        socketFD = -1;
        if (flags.connected)
            (void)shutdown(closingFD, 2); // disallow further sends and receives
        socketCloseSucceeded = OBSocketClose(closingFD) == 0;
        OBASSERT(socketCloseSucceeded);
    }

    flags.connected = NO;
}

// Debugging

// Extends the superclass's debug dictionary with this socket's descriptor (when
// it has one) and its listening/connected/userAbort flags.
- (NSMutableDictionary *)debugDictionary;
{
    NSMutableDictionary *debugDictionary;

    debugDictionary = [super debugDictionary];
    // -1 is the "no descriptor" sentinel used throughout this class, and 0 is a
    // valid descriptor, so test against -1 rather than for non-zero.
    if (socketFD != -1)
        [debugDictionary setObject:[NSNumber numberWithInt:socketFD] forKey:@"socketFD"];
    [debugDictionary setObject:flags.listening ? @"YES" : @"NO" forKey:@"listening"];
    [debugDictionary setObject:flags.connected ? @"YES" : @"NO" forKey:@"connected"];
    [debugDictionary setObject:flags.userAbort ? @"YES" : @"NO" forKey:@"userAbort"];

    return debugDictionary;
}

@end

@implementation ONInternetSocket (SubclassAPI)

// Initializer used by +socket and subclasses. The socket takes the descriptor
// `aSocketFD` (closed in -dealloc) and records whether it is already connected.
- _initWithSocketFD:(int)aSocketFD connected:(BOOL)isConnected;
{
    // Assign the result: -[super init] may return a different object or nil,
    // and continuing with the stale receiver would write ivars into the wrong object.
    if (!(self = [super init]))
        return nil;

    socketFD = aSocketFD;
    flags.connected = isConnected;

    return self;
}

@end

@implementation ONInternetSocket (Private)

// Opens a new descriptor using the receiving class's protocol family, socket
// type and IP protocol. Raises ONInternetSocketConnectFailedExceptionName if
// socket() fails.
+ (int)createSocketFD;
{
    int family = [self protocolFamily];
    int type = [self socketType];
    int protocol = [self ipProtocol];
    int createdFD = socket(family, type, protocol);

    if (createdFD == -1)
        [NSException raise:ONInternetSocketConnectFailedExceptionName posixErrorNumber:OMNI_ERRNO() format:@"Unable to create a socket: %s", strerror(OMNI_ERRNO())];

    return createdFD;
}

@end


// Definitions of the exception names raised by ONInternetSocket (several are
// used in this file, e.g. bind, connect, getsockname/getpeername and select
// failures, and user aborts).
DEFINE_NSSTRING(ONInternetSocketBroadcastSelectionFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketReuseSelectionFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketBindFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketConnectFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketConnectTemporarilyFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketGetNameFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketNotConnectedExceptionName);
DEFINE_NSSTRING(ONInternetSocketReadFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketUserAbortExceptionName);
DEFINE_NSSTRING(ONInternetSocketWriteFailedExceptionName);
DEFINE_NSSTRING(ONInternetSocketCloseFailedExceptionName);
