// Copyright 1997-2000 Omni Development, Inc.  All rights reserved.
//
// This software may only be used and reproduced according to the
// terms in the file OmniSourceLicense.html, which should be
// distributed with this project and can also be found at
// http://www.omnigroup.com/DeveloperResources/OmniSourceLicense.html.

#import "ONUDPSocket.h"

#import <Foundation/Foundation.h>
#import <OmniBase/OmniBase.h>
#import <OmniBase/system.h>

#import "ONHost.h"
#import "ONPortAddress.h"

// CVS revision header for this file. RCS_ID is an OmniBase macro; presumably it embeds this
// string in the compiled object for identification — confirm against OmniBase.
RCS_ID("$Header: /Network/Source/CVS/OmniGroup/Frameworks/OmniNetworking/ONUDPSocket.m,v 1.16 2000/11/16 12:47:48 wjs Exp $")

@implementation ONUDPSocket

// Sends the first byteCount bytes of aBuffer to aPortAddress with sendto().
// Because sendto() is used directly, this also works on a socket that has been connected.
// Returns the byte count reported by sendto(); raises ONInternetSocketWriteFailedExceptionName
// if sendto() fails.
- (unsigned int)writeBytes:(unsigned int)byteCount fromBuffer:(const void *)aBuffer toPortAddress:(ONPortAddress *)aPortAddress;
{
    // Note, you can be connected and still do a sendto()
    int bytesWritten;
    const struct sockaddr_in *portAddress;

    portAddress = [aPortAddress portAddress];
    bytesWritten = sendto(socketFD, (char *)aBuffer, byteCount, 0, (struct sockaddr *)portAddress, sizeof(*portAddress));
    if (bytesWritten == -1) {
        // Read the error code before sending any Objective-C messages. The bundle lookup and
        // localization below may overwrite errno, and the evaluation order of the arguments
        // to -raise:format: is unspecified.
        int writeError = OMNI_ERRNO();

        [NSException raise:ONInternetSocketWriteFailedExceptionName format:NSLocalizedStringFromTableInBundle(@"Unable to write to socket: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(writeError)];
    }
    return (unsigned int)bytesWritten;
}


// ONSocket subclass

// Reads up to byteCount bytes into aBuffer and returns the number of bytes read.
// - Connected sockets use recv().
// - Unconnected sockets use recvfrom(). The sender's address is stored in remoteAddress
//   (allocated lazily on the first such read) and is zeroed if nothing was received.
//   Any cached remoteHost is released so it will be recomputed from the new address.
// Raises ONInternetSocketUserAbortExceptionName if flags.userAbort is set once the read
// returns, otherwise ONInternetSocketReadFailedExceptionName if the read failed.
- (unsigned int)readBytes:(unsigned int)byteCount intoBuffer:(void *)aBuffer;
{
    int bytesRead;
    int readError = 0;

    if (flags.connected) {
        bytesRead = recv(socketFD, aBuffer, byteCount, 0);
        if (bytesRead == -1)
            readError = OMNI_ERRNO();
    } else {
        // NOTE(review): modern headers declare recvfrom()'s length parameter as socklen_t *;
        // int is kept here for compatibility with the platforms this file was written for.
        int remoteAddressLength;

        remoteAddressLength = sizeof(*remoteAddress);
        if (!remoteAddress) {
            // Allocated on the first unconnected read and reused by every later one.
            remoteAddress = NSZoneMalloc(NULL, remoteAddressLength);
        }

        bytesRead = recvfrom(socketFD, aBuffer, byteCount, 0, (struct sockaddr *)remoteAddress, &remoteAddressLength);
        if (bytesRead == -1) {
            // Capture the error now: the memset and -release below run before we raise,
            // and releasing the host object may overwrite errno.
            readError = OMNI_ERRNO();
        }
        OBASSERT(bytesRead == -1 || remoteAddressLength == 0 || remoteAddressLength == sizeof(*remoteAddress));
        if (bytesRead == -1 || remoteAddressLength == 0) {
            // We didn't really receive anything so make sure the remote address is nulled out.
            memset(remoteAddress, 0, sizeof(*remoteAddress));
        }
        // ONInternetSocket caches the remote host, flush that cache
        if (remoteHost) {
            [remoteHost release];
            remoteHost = nil;
        }
    }

    if (flags.userAbort)
        [NSException raise:ONInternetSocketUserAbortExceptionName format:NSLocalizedStringFromTableInBundle(@"Read aborted", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error)];

    if (bytesRead == -1)
        [NSException raise:ONInternetSocketReadFailedExceptionName format:NSLocalizedStringFromTableInBundle(@"Unable to read from socket: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(readError)];
    return (unsigned int)bytesRead;
}

// Sends the first byteCount bytes of aBuffer to the connected peer with send().
// Raises ONInternetSocketNotConnectedExceptionName if the socket is not connected, or
// ONInternetSocketWriteFailedExceptionName if send() fails.
- (unsigned int)writeBytes:(unsigned int)byteCount fromBuffer:(const void *)aBuffer;
{
    int bytesWritten;

    if (!flags.connected)
        [NSException raise:ONInternetSocketNotConnectedExceptionName format:@"Attempted write to a non-connected socket"];
    bytesWritten = send(socketFD, (char *)aBuffer, byteCount, 0);
    if (bytesWritten == -1) {
        // Capture errno before the message sends below can overwrite it.
        int writeError = OMNI_ERRNO();

        [NSException raise:ONInternetSocketWriteFailedExceptionName format:NSLocalizedStringFromTableInBundle(@"Unable to write to socket: %s", @"OmniNetworking", [NSBundle bundleForClass:[self class]], error), strerror(writeError)];
    }
    return (unsigned int)bytesWritten;
}


// ONInternetSocket subclass

// Socket type used when creating the underlying socket: a datagram socket.
+ (int)socketType;
{
    return SOCK_DGRAM;
}

// Protocol used when creating the underlying socket: UDP.
+ (int)ipProtocol;
{
    return IPPROTO_UDP;
}

// Returns the sender address recorded by the most recent unconnected -readBytes:intoBuffer:.
// The result is NULL before any unconnected read, and all-zero if that read received nothing.
- (const struct sockaddr_in *)remoteAddress;
{
    // The ONInternetSocket implementation of this tries to find the remote address by using getpeername(), which only works on connected sockets.
    // We subclass this method to return the remote address associated with the last read.
    return remoteAddress;
}

// Connects (or re-connects) the datagram socket to host:port.
- (void)connectToHost:(ONHost *)host port:(unsigned short int)port;
{
    // ONInternetSocket's implementation will return without doing a connect if it's already connected.
    // Clearing the flag first lets a UDP socket be re-targeted at a different host/port.
    flags.connected = NO;
    [super connectToHost:host port:port];
}

@end
