 
package common;

import java.io.*;
import java.net.*;
import java.rmi.server.*;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Arrays;
import javax.net.ssl.*;


/**
 * Class RMISSLClientSocketFactory is similar in its purpose to the
 * default JDK javax.rmi.ssl.SslRMIClientSocketFactory but differs in
 * its ability to initialize key sources in a non-standard way.
 * 
 * Much like javax.rmi.ssl.SslRMIClientSocketFactory, and for the same
 * reason, RMISSLClientSocketFactory does not initialize its internals
 * when the constructor is called, deferring that to createSocket()
 * time: the RMI client socket factory is created on the server side,
 * where that initialization is a priori meaningless unless server and
 * client run in the same JVM.  We could override readObject() to force
 * this initialization, but mixing it with possible deserialization
 * problems is probably not a good idea.  So, in contrast to what we do
 * on the server side, initialization of the SSLSocketFactory is delayed
 * until createSocket() is called (it is rebuilt on every call).  Note
 * that the default SSLSocketFactory may already have been initialized
 * anyway if code elsewhere in the JVM has called
 * SSLSocketFactory.getDefault().
 */

public class RMISSLClientSocketFactory implements RMIClientSocketFactory, Serializable {
  public static boolean TRACE = false;
  static void out (String msg) { System.out.println("RMISSL Client Socket Factory: " + msg); }

  // NOTE: every helper added to this class is private. Private members do not
  // take part in the default serialVersionUID computation, so instances stay
  // wire-compatible with peers that run the previous version of this class
  // (the factory travels from the RMI server to the client by serialization).

  // Classpath resource whose first line is the passphrase protecting the
  // client key store and trust store (used only when isDefaultKM is false).
  private static final String PASSPHRASE_RESOURCE = "/tmp/client_input";

  // Selects the key-manager setup used by createSocket():
  //   true  - the JSSE default SSLSocketFactory, configured through the
  //           javax.net.ssl.* system properties;
  //   false - a TLS context built from the client key store and trust store
  //           named by KeytoolAttrs.
  // This field is serialized with the factory, so after a client connects to
  // a server, the client-side factory uses whatever the server chose. For
  // instance, the client may initially use rmis_ssl_dflt while the server uses
  // the prog variant; the client then switches to that variant, which keeps
  // working as long as the credentials it needs are reachable.
  boolean isDefaultKM;

  public RMISSLClientSocketFactory (boolean isDefaultKM) {
    this.isDefaultKM = isDefaultKM;
    if (TRACE) out("Creating RMISSLClientSocketFactory." +
                   (isDefaultKM ? " with default KM" : " with custom KM"));
  }
  public RMISSLClientSocketFactory () { this(true); }

  /**
   * Builds the SSL socket factory selected by {@code isDefaultKM}.
   *
   * @return the JSSE default factory when {@code isDefaultKM} is true;
   *         otherwise a factory backed by the client key and trust stores,
   *         or {@code null} if building it failed (the failure is printed
   *         only when {@code TRACE} is on)
   */
  public SSLSocketFactory initFactory() {
    try {
      return newFactory();
    } catch (Exception e) {
      if (TRACE) {
        out("RMISSLClientSocketFactory: got " + e);
        e.printStackTrace();
      }
      return null;
    }
  }

  // Same as initFactory(), but reports failures to the caller instead of
  // returning null.
  private SSLSocketFactory newFactory() throws GeneralSecurityException, IOException {
    if (TRACE) out("Enter initFactory. isDefaultKM = " + isDefaultKM);
    return isDefaultKM ? defaultFactory() : customFactory();
  }

  // Returns the JSSE default factory. It relies on the default key and trust
  // store setup (see JSSE), specified when invoking this code as
  //   -Djavax.net.ssl.keyStore=keystore
  //   -Djavax.net.ssl.keyStorePassword=password
  //   -Djavax.net.ssl.trustStore=truststore
  //   -Djavax.net.ssl.trustStorePassword=trustword
  // Passwords are never written to the trace output.
  private static SSLSocketFactory defaultFactory() {
    if (TRACE) {
      out("javax.net.ssl.keyStore: " + System.getProperty("javax.net.ssl.keyStore"));
      out("javax.net.ssl.keyStorePassword: "
          + mask(System.getProperty("javax.net.ssl.keyStorePassword")));
      out("javax.net.ssl.trustStore: " + System.getProperty("javax.net.ssl.trustStore"));
      out("javax.net.ssl.trustStorePassword: "
          + mask(System.getProperty("javax.net.ssl.trustStorePassword")));
      out("createSocket: Get default socket factory...");
    }
    SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
    if (TRACE) out("Done get SSLSocketFactory " + factory + " hash: " + factory.hashCode());
    return factory;
  }

  // Builds a TLS socket factory with an explicitly configured key manager
  // (client authentication) and trust manager (server authentication). Both
  // stores are JKS files named by KeytoolAttrs and unlocked with the
  // passphrase read from PASSPHRASE_RESOURCE.
  private SSLSocketFactory customFactory() throws GeneralSecurityException, IOException {
    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");

    char[] passphrase = readPassphrase();
    try {
      if (TRACE) out("Load keystore ...");
      // File system first: in the original load sequence the file contents
      // were the ones that ended up in the key store.
      KeyStore keystore = loadStore(KeytoolAttrs.clientKeyStore(), passphrase, true);
      if (TRACE) out("Done load keystore " + keystore);

      if (TRACE) out("Load truststore ...");
      // Classpath first, as before; the file system is now a fallback
      // instead of silently producing an empty trust store.
      KeyStore truststore = loadStore(KeytoolAttrs.clientTrustStore(), passphrase, false);
      if (TRACE) out("Done load truststore " + truststore);

      kmf.init(keystore, passphrase);
      tmf.init(truststore);
    } finally {
      // Clear the passphrase once it has been used.
      Arrays.fill(passphrase, '\0');
    }

    if (TRACE) out("Create SSLContext ...");
    SSLContext ctx = SSLContext.getInstance("TLS");
    ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
    if (TRACE) out("Done create SSLContext " + ctx);

    SSLSocketFactory factory = ctx.getSocketFactory();
    if (TRACE) {
      out("Done get SSLSocketFactory " + factory);
      // Diagnostic only: fetching the default factory initializes the
      // default SSL context, so do it only when tracing.
      out("Compare that with the default socket factory...");
      SSLSocketFactory dflt = (SSLSocketFactory) SSLSocketFactory.getDefault();
      out("Done get SSLSocketFactory " + dflt + " hash: " + dflt.hashCode());
    }
    return factory;
  }

  // Reads the first line of the classpath resource PASSPHRASE_RESOURCE.
  // Throws if the resource is missing or empty, rather than failing later
  // with a NullPointerException.
  private char[] readPassphrase() throws IOException {
    InputStream pwis = getClass().getResourceAsStream(PASSPHRASE_RESOURCE);
    if (TRACE) out("-- pwis: " + pwis);
    if (pwis == null) {
      throw new FileNotFoundException("passphrase resource not found: " + PASSPHRASE_RESOURCE);
    }
    BufferedReader reader = new BufferedReader(new InputStreamReader(pwis));
    try {
      String line = reader.readLine();
      if (line == null) {
        throw new EOFException("passphrase resource is empty: " + PASSPHRASE_RESOURCE);
      }
      char[] passphrase = line.toCharArray();
      if (TRACE) out("passphrase is array of len: " + passphrase.length);
      return passphrase;
    } finally {
      reader.close();
    }
  }

  // Loads the JKS store `name`, trying both the file system and the
  // classpath (resolved relative to this class). `fileFirst` selects which
  // location wins when both exist. The stream is always closed.
  private KeyStore loadStore(String name, char[] passphrase, boolean fileFirst)
      throws GeneralSecurityException, IOException {
    InputStream in;
    if (fileFirst) {
      in = openFile(name);
      if (in == null) in = getClass().getResourceAsStream(name);
    } else {
      in = getClass().getResourceAsStream(name);
      if (in == null) in = openFile(name);
    }
    if (TRACE) out("*store: " + name + " stream: " + in);
    if (in == null) {
      throw new FileNotFoundException("store not found on file system or classpath: " + name);
    }
    try {
      KeyStore store = KeyStore.getInstance("JKS");
      store.load(in, passphrase);
      return store;
    } finally {
      in.close();
    }
  }

  // Opens `name` as a regular file, or returns null when no such file exists.
  private static InputStream openFile(String name) throws FileNotFoundException {
    File f = new File(name);
    return f.isFile() ? new FileInputStream(f) : null;
  }

  // Trace helper: tells whether a secret is set without revealing its value.
  private static String mask(String secret) {
    return secret == null ? "null" : "<hidden>";
  }

  /**
   * Creates an SSL socket connected to {@code host:port}, using a freshly
   * built factory as selected by {@code isDefaultKM}.
   *
   * @throws IOException if the key material cannot be read or is invalid,
   *         or if the connection cannot be established
   */
  @Override
  public Socket createSocket(String host, int port) throws IOException {
    SSLSocketFactory factory;
    try {
      factory = newFactory();
    } catch (GeneralSecurityException e) {
      throw new IOException("cannot initialize SSL socket factory: " + e, e);
    }
    if (TRACE) out("createSocket: Done get SSLSocketFactory " + factory);

    if (TRACE) out("Use " + factory + " to create SSLSocket on port " + port);
    SSLSocket s = (SSLSocket) factory.createSocket(host, port);
    if (TRACE) out("Got SSLSocket " + s + " class:" + s.getClass());
    return s;
  }

  // Two factories are equivalent exactly when they are of the same class and
  // select the same key-manager setup; hashCode must agree with that.
  @Override
  public int hashCode() {
    return 31 * getClass().hashCode() + (isDefaultKM ? 1 : 0);
  }

  @Override
  public boolean equals(Object obj) {
    if (TRACE) out("Enter equals");
    if (obj == this) {
      return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
      return false;
    }
    return isDefaultKM == ((RMISSLClientSocketFactory) obj).isDefaultKM;
  }
}

