// Copyright 1997-2001 Omni Development, Inc.  All rights reserved.
//
// This software may only be used and reproduced according to the
// terms in the file OmniSourceLicense.html, which should be
// distributed with this project and can also be found at
// http://www.omnigroup.com/DeveloperResources/OmniSourceLicense.html.

#import "ONTCPSocket.h"

#import <Foundation/Foundation.h>
#import <OmniBase/OmniBase.h>
#import <OmniBase/system.h>

#import "ONHost.h"
#import "ONHostAddress.h"
#import "ONInternetSocket-Private.h"
#import "ONServiceEntry.h"

#import <errno.h>

#if !defined(WIN32)
#import <sys/fcntl.h>
#endif    


RCS_ID("$Header: /Network/Source/CVS/OmniGroup/Frameworks/OmniNetworking/ONTCPSocket.m,v 1.33 2001/08/18 01:14:02 kc Exp $")

// File-private helpers used by the ONTCPSocket implementation below.
@interface ONTCPSocket (Private)
// Accepts one pending connection on the listening descriptor and returns the new connected descriptor (raises on failure).
- (int)socketFDForAcceptedConnection;
@end

@implementation ONTCPSocket

// The class that +tcpSocket instantiates.  Set to ONTCPSocket in +initialize; may be replaced via +setDefaultTCPSocketClass:.
static Class defaultTCPSocketClass = nil;

// One-time class setup: makes ONTCPSocket the default socket class and, on non-Win32 platforms, ignores SIGPIPE process-wide so that writing to a broken connection fails with an error (which the read/write methods turn into exceptions) instead of delivering a signal.
+ (void)initialize;
{
    OBINITIALIZE;

    defaultTCPSocketClass = [ONTCPSocket class];

#ifndef WIN32
    // get rid of pesky SIGPIPE signals - we want an exception instead
    signal(SIGPIPE, SIG_IGN);
#endif
}

// Returns the class that +tcpSocket instantiates.
+ (Class)defaultTCPSocketClass;
{
    return defaultTCPSocketClass;
}

// Replaces the class that +tcpSocket instantiates.  aClass is expected to be ONTCPSocket or a subclass of it; this is checked with an assertion.
+ (void)setDefaultTCPSocketClass:(Class)aClass;
{
    Class ancestor;
    BOOL isTCPSocketClass;

    // Walk up the superclass chain looking for ONTCPSocket.
    ancestor = aClass;
    while (ancestor != Nil && ancestor != [ONTCPSocket class])
        ancestor = [ancestor superclass];
    isTCPSocketClass = (ancestor != Nil);
    OBASSERT(isTCPSocketClass);

    defaultTCPSocketClass = aClass;
}

// Returns a new socket obtained by sending +socket to the current default TCP socket class.
+ (ONTCPSocket *)tcpSocket;
{
    return (ONTCPSocket *)[defaultTCPSocketClass socket];
}

//

// Turns non-blocking I/O on or off for the socket's descriptor.  Only the non-blocking bit is changed; other file status flags already set on the descriptor are preserved.  Currently a no-op on Win32.
- (void)setNonBlocking:(BOOL)shouldBeNonBlocking;
{
#if defined(WIN32)
    #warning No nonblocking on win32
#else
    int nonBlockingFlag;
    int fileStatusFlags;

#if defined(NeXT) || !defined(YELLOW_BOX)
    // BSD 4.3
    nonBlockingFlag = FNDELAY;
#else
    // POSIX
    nonBlockingFlag = O_NONBLOCK;
#endif

    // Read-modify-write: F_SETFL replaces *all* file status flags, so setting a literal value would clear any others.
    fileStatusFlags = fcntl(socketFD, F_GETFL, 0);
    if (fileStatusFlags == -1)
        fileStatusFlags = 0; // Best effort, as before: fall back to setting just the non-blocking bit.
    if (shouldBeNonBlocking)
        fileStatusFlags |= nonBlockingFlag;
    else
        fileStatusFlags &= ~nonBlockingFlag;
    fcntl(socketFD, F_SETFL, fileStatusFlags);
#endif
}

// Binds to a port chosen by the system (port 0) and starts listening.
- (void)startListeningOnAnyLocalPort;
{
    [self startListeningOnLocalPort:0];
}

// Backlog passed to listen(2): the maximum number of pending connections queued before accept.
#define PENDING_CONNECTION_LIMIT 5

// Binds to the given local port (without address reuse) and starts listening.
- (void)startListeningOnLocalPort:(unsigned short int)port;
{
    [self startListeningOnLocalPort:port allowingAddressReuse:NO];
}

// Binds to the given local port, optionally allowing address reuse, then calls listen(2).
// Raises ONTCPSocketListenFailedExceptionName (carrying the POSIX error number) if listen fails.
- (void)startListeningOnLocalPort:(unsigned short int)port allowingAddressReuse:(BOOL)reuse;
{
    int listenErrno;

    [self setLocalPortNumber:port allowingAddressReuse:reuse];

    if (listen(socketFD, PENDING_CONNECTION_LIMIT) == -1) {
        // Capture errno once so the reported number and the message text always agree.
        listenErrno = OMNI_ERRNO();
        [NSException raise:ONTCPSocketListenFailedExceptionName posixErrorNumber:listenErrno format:@"Unable to listen on socket: %s", strerror(listenErrno)];
    }
    flags.listening = YES;
}

// Starts listening on the port number of the given service entry.
- (void)startListeningOnLocalService:(ONServiceEntry *)service;
{
    [self startListeningOnLocalPort:[service portNumber]];
}

// Accepts one pending connection and turns this socket into that connection.
// The listening descriptor is closed and replaced by the accepted one, so no further connections are accepted on this object.
- (void)acceptConnection;
{
    int newSocketFD;
    BOOL socketCloseSucceeded;

    newSocketFD = [self socketFDForAcceptedConnection];
    socketCloseSucceeded = OBSocketClose(socketFD) == 0;
    OBASSERT(socketCloseSucceeded);
    socketFD = newSocketFD;
    flags.connected = YES;
    flags.listening = NO;
}

// Accepts one pending connection and returns it as a new autoreleased socket of the receiver's class.  The receiver keeps listening.
- (ONTCPSocket *)acceptConnectionOnNewSocket;
{
    return [[[isa alloc] _initWithSocketFD:[self socketFDForAcceptedConnection] connected:YES] autorelease];
}

// ONInternetSocket subclass

// TCP sockets are stream sockets.
+ (int)socketType;
{
    return SOCK_STREAM;
}

// TCP sockets use the TCP protocol.
+ (int)ipProtocol;
{
    return IPPROTO_TCP;
}

// Reads up to byteCount bytes into aBuffer and returns the number of bytes read.
// If the socket is listening but not yet connected, a pending connection is accepted first.
// Returns 0 at end-of-file (or on EPIPE), after closing this end of the connection.  A zero-length request returns 0 without touching the connection.
// Raises ONInternetSocketNotConnectedExceptionName if neither connected nor listening,
// ONInternetSocketUserAbortExceptionName if the read failed after a user abort,
// ONTCPSocketWouldBlockExceptionName on EAGAIN, and ONInternetSocketReadFailedExceptionName for any other read error.
- (unsigned int)readBytes:(unsigned int)byteCount intoBuffer:(void *)aBuffer;
{
    BOOL socketCloseSucceeded;
    int bytesRead;
    int readErrno;

    while (!flags.connected) {
        if (!flags.listening)
            [NSException raise:ONInternetSocketNotConnectedExceptionName format:@"Attempted read from a non-connected socket"];
        else
            [self acceptConnection];
    }

    // read() of zero bytes returns 0, which is indistinguishable from end-of-file and would otherwise close a healthy connection below.
    if (byteCount == 0)
        return 0;

    bytesRead = OBSocketRead(socketFD, aBuffer, byteCount);
    switch (bytesRead) {
        case -1:
            // Error reading socket.  Capture errno immediately: the Foundation calls made while building the exception below may overwrite it.
            readErrno = OMNI_ERRNO();
            if (flags.userAbort)
                [NSException raise:ONInternetSocketUserAbortExceptionName format:NSLocalizedStringFromTableInBundle(@"Read aborted", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error: userAbort)];
            if (readErrno == EAGAIN)
                [NSException raise:ONTCPSocketWouldBlockExceptionName format:NSLocalizedStringFromTableInBundle(@"Read aborted", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error: EAGAIN)];
            if (readErrno == EPIPE)
                goto read_eof;
            [NSException raise:ONInternetSocketReadFailedExceptionName posixErrorNumber:readErrno format:NSLocalizedStringFromTableInBundle(@"Unable to read from socket: %s", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error), strerror(readErrno)];
            return 0; // Not reached
        case 0:
        read_eof:
            // Our peer closed the socket, resulting in an end-of-file.  Close it on this end.
            flags.connected = NO;
            // 0 can be returned when we closed the socket ourselves (from another thread), so check to see whether we still have a file descriptor.
            if (socketFD != -1) {
                socketCloseSucceeded = OBSocketClose(socketFD) == 0;
                OBASSERT(socketCloseSucceeded);
                socketFD = -1;
            }
            return 0;
        default:
            // Normal successful read
            return (unsigned int)bytesRead;
    }
}

// Writes up to byteCount bytes from aBuffer and returns the number actually written, which may be less than byteCount (and is additionally capped at MAX_BYTES_PER_WRITE when that macro is defined).
// If the socket is listening but not yet connected, a pending connection is accepted first.
// Raises ONInternetSocketNotConnectedExceptionName if neither connected nor listening,
// ONInternetSocketUserAbortExceptionName if the write failed after a user abort,
// ONTCPSocketWouldBlockExceptionName on EAGAIN, and ONInternetSocketWriteFailedExceptionName for any other write error.
- (unsigned int)writeBytes:(unsigned int)byteCount fromBuffer:(const void *)aBuffer;
{
    int bytesWritten;

    while (!flags.connected) {
        if (!flags.listening)
            [NSException raise:ONInternetSocketNotConnectedExceptionName
             format:@"Attempted write to a non-connected socket"];
        else
            [self acceptConnection];
    }
#ifndef MAX_BYTES_PER_WRITE
    bytesWritten = OBSocketWrite(socketFD, aBuffer, byteCount);
#else
    bytesWritten = OBSocketWrite(socketFD, aBuffer, byteCount > MAX_BYTES_PER_WRITE ? MAX_BYTES_PER_WRITE : byteCount);
#endif
    if (bytesWritten == -1) {
        // Capture errno immediately: the Foundation calls made while building the exception below may overwrite it.
        int writeErrno = OMNI_ERRNO();

        if (flags.userAbort)
            [NSException raise:ONInternetSocketUserAbortExceptionName format:NSLocalizedStringFromTableInBundle(@"Write aborted", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error: userAbort)];
        if (writeErrno == EAGAIN)
            [NSException raise:ONTCPSocketWouldBlockExceptionName format:NSLocalizedStringFromTableInBundle(@"Write aborted", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error: EAGAIN)];
        // Look the string up in OmniNetworking's own bundle (like every other message here), not in a subclass's bundle that may lack this table.
        [NSException raise:ONInternetSocketWriteFailedExceptionName posixErrorNumber:writeErrno format:NSLocalizedStringFromTableInBundle(@"Unable to write to socket: %s", @"OmniNetworking", [NSBundle bundleForClass:[ONTCPSocket class]], error), strerror(writeErrno)];
    }
    
    return (unsigned int)bytesWritten;
}

@end

@implementation ONTCPSocket (Private)

// Calls accept(2) on the receiver's listening descriptor, retrying whenever the call is interrupted by a signal (EINTR).
// On success the peer's address is copied into remoteAddress (allocated on first use) and the new connected descriptor is returned; the listening descriptor itself is left untouched.
// Any other accept failure raises ONTCPSocketAcceptFailedExceptionName carrying the POSIX error number.
- (int)socketFDForAcceptedConnection;
{
    struct sockaddr_in peerAddress;
    int peerAddressLength = sizeof(struct sockaddr_in);
    int acceptedFD;

    // Keep retrying until accept succeeds or fails with something other than EINTR.
    for (;;) {
        acceptedFD = accept(socketFD, (struct sockaddr *)&peerAddress, &peerAddressLength);
        if (acceptedFD != -1 || OMNI_ERRNO() != EINTR)
            break;
    }

    if (acceptedFD == -1) {
        int acceptErrno = OMNI_ERRNO();

        [NSException raise:ONTCPSocketAcceptFailedExceptionName posixErrorNumber:acceptErrno format:@"Socket accept failed: %s", strerror(acceptErrno)];
    }

    // Remember who connected; the buffer is allocated lazily and reused by later accepts.
    if (remoteAddress == NULL)
        remoteAddress = NSZoneMalloc(NULL, sizeof(struct sockaddr_in));
    *remoteAddress = peerAddress;

    return acceptedFD;
}

@end


// Exception-name constants raised by this file, defined via OmniBase's DEFINE_NSSTRING macro:
//   ListenFailed - listen(2) failed when starting to listen on a port.
//   AcceptFailed - accept(2) failed while accepting a pending connection.
//   WouldBlock   - a read or write failed with EAGAIN (non-blocking socket had nothing to do).
DEFINE_NSSTRING(ONTCPSocketListenFailedExceptionName);
DEFINE_NSSTRING(ONTCPSocketAcceptFailedExceptionName);
DEFINE_NSSTRING(ONTCPSocketWouldBlockExceptionName);
