
package common;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;

import java.net.ServerSocket;

import java.rmi.server.RMIServerSocketFactory;

import java.security.GeneralSecurityException;
import java.security.KeyStore;

import java.util.Arrays;

//import javax.net.ssl.*;
// import javax.net.ssl.*;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.TrustManagerFactory;

/**
 * RMI server socket factory that produces SSL server sockets.
 *
 * <p>Similar in purpose to the JDK's {@code javax.rmi.ssl.SslRMIServerSocketFactory}, but able to
 * initialize its key sources in a non-standard way. Besides the JSSE default factory (configured
 * through the {@code javax.net.ssl.*} system properties), it can build its own
 * {@link SSLContext} from JKS key and trust stores loaded as classpath resources. Those stores
 * are unlocked with a passphrase that is itself read from a classpath resource.
 *
 * <p>The most recently created server socket is remembered. {@link #createServerSocket(int)}
 * hands it back again while it is still open and the requested port is either 0 or the port it
 * is bound to.
 */
public class RMISSLServerSocketFactory implements RMIServerSocketFactory {
  /** When {@code true}, progress and diagnostic messages are printed to standard output. */
  public static boolean TRACE = false;

  /** Prints {@code msg} to standard output, prefixed with the factory name. */
  static void out(String msg) {
    System.out.println("RMISSL Server Socket Factory: " + msg);
  }

  /** Classpath resource whose first line is the passphrase for the custom key/trust stores. */
  private static final String PASSPHRASE_RESOURCE = "/tmp/server_input";

  /*
   * Create one SSLServerSocketFactory per instance, so sockets created later can reuse
   * sessions established through the same SSLContext.
   */
  private final SSLServerSocketFactory ssf;

  /* Keep track of the [current/last] server socket */
  private ServerSocket ss = null;

  /** Returns the server socket most recently created by this factory, or {@code null}. */
  public ServerSocket getServerSocket() {
    return ss;
  }

  /**
   * Creates the factory and its underlying {@link SSLServerSocketFactory}.
   *
   * @param isDefault {@code true} to use the JSSE default factory (configured by the
   *     {@code javax.net.ssl.*} system properties). {@code false} to build an {@link SSLContext}
   *     from the JKS stores named by {@code KeytoolAttrs.serverKeyStore()} and
   *     {@code KeytoolAttrs.serverTrustStore()}, both loaded as classpath resources and unlocked
   *     with the first line of the classpath resource {@code /tmp/server_input}.
   * @throws Exception if a resource is missing or empty, or any SSL setup step fails. The
   *     exception is printed to standard output before being rethrown.
   */
  public RMISSLServerSocketFactory(boolean isDefault) throws Exception {
    try {
      ssf = isDefault ? createDefaultFactory() : createCustomFactory();
    } catch (Exception e) {
      System.out.println("-- RMISSLServerSocketFactory: got " + e);
      e.printStackTrace();
      throw e;
    }
  }

  // Default key and trust store initialization.
  // Key and trust store access are defined by system properties
  //    javax.net.ssl.keyStore
  //    javax.net.ssl.keyStorePassword
  //    javax.net.ssl.trustStore
  //    javax.net.ssl.trustStorePassword
  //
  // See the JSSE Reference Guide, section 'Creating X509KeyManager', and also
  // 'Customizing the Default Key and Trust Stores, Store Types, and Store Passwords':
  //
  // Whenever a default SSLSocketFactory or SSLServerSocketFactory is
  // created (via a call to SSLSocketFactory.getDefault or
  // SSLServerSocketFactory.getDefault), and this default
  // SSLSocketFactory (or SSLServerSocketFactory) comes from the JSSE
  // reference implementation, a default SSLContext is associated with
  // the socket factory. (The default socket factory will come from the
  // JSSE implementation.)
  //
  // This default SSLContext is initialized with a default KeyManager
  // and a TrustManager. If a keystore is specified by the
  // javax.net.ssl.keyStore system property, then the KeyManager created
  // by the default SSLContext will be a KeyManager implementation for
  // managing the specified keystore. (The actual implementation will be
  // as specified in Customizing the Default Key and Trust Managers.) If
  // no such system property is specified, then the keystore managed by
  // the KeyManager will be a new empty keystore.
  //
  // Similarly, if a truststore is specified by the
  // javax.net.ssl.trustStore system property, then the TrustManager
  // created by the default SSLContext will be a TrustManager
  // implementation for managing the specified truststore. In this case,
  // if such a property exists but the file it specifies doesn't, then
  // no truststore is utilized. If no javax.net.ssl.trustStore property
  // exists, then a default truststore is searched for. If a truststore
  // named <java-home>/lib/security/jssecacerts is found, it is used. If
  // not, then a truststore named <java-home>/lib/security/cacerts is
  // searched for and used (if it exists). See The Installation
  // Directory <java-home> for information as to what <java-home> refers
  // to. Finally, if a truststore is still not found, then the
  // truststore managed by the TrustManager will be a new empty
  // truststore.

  /** Returns the JSSE default server socket factory, tracing the relevant properties. */
  private static SSLServerSocketFactory createDefaultFactory() {
    if (TRACE) {
      traceDefaultSslSettings();
      out("Get Default ServerSocketFactory ...");
    }
    SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
    if (TRACE) out("Done Get Default ServerSocketFactory: " + factory);
    return factory;
  }

  /** Traces the javax.net.ssl store properties. Password values are masked, never printed. */
  private static void traceDefaultSslSettings() {
    out("javax.net.ssl.keyStore: " + System.getProperty("javax.net.ssl.keyStore"));
    out("Set javax.net.ssl.keyStorePassword: " + maskedProperty("javax.net.ssl.keyStorePassword"));
    out("Set javax.net.ssl.trustStore: " + System.getProperty("javax.net.ssl.trustStore"));
    out("Set javax.net.ssl.trustStorePassword: "
        + maskedProperty("javax.net.ssl.trustStorePassword"));
  }

  /** Returns "null" if the property is unset, otherwise a placeholder hiding its value. */
  private static String maskedProperty(String name) {
    return System.getProperty(name) == null ? "null" : "<set, hidden>";
  }

  /**
   * Builds a server socket factory from the classpath key/trust stores. The key manager
   * performs server authentication and the trust manager checks client certificates.
   */
  private SSLServerSocketFactory createCustomFactory()
      throws GeneralSecurityException, IOException {
    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");

    char[] passphrase = readPassphrase();
    try {
      KeyStore keystore = loadKeyStore("keystore", KeytoolAttrs.serverKeyStore(), passphrase);
      KeyStore truststore =
          loadKeyStore("truststore", KeytoolAttrs.serverTrustStore(), passphrase);

      kmf.init(keystore, passphrase);
      tmf.init(truststore);
      if (TRACE) out("Done initializing stores.");

      if (TRACE) out("Create SSLContext ...");
      SSLContext ctx = SSLContext.getInstance("TLS");
      ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
      if (TRACE) out("Done create SSLContext " + ctx);

      if (TRACE) out("Get ServerSocketFactory ...");
      SSLServerSocketFactory factory = ctx.getServerSocketFactory();
      if (TRACE) out("Done get ServerSocketFactory " + factory);
      return factory;
    } finally {
      // Key material has been extracted by now; do not keep the secret around in memory.
      Arrays.fill(passphrase, '\0');
    }
  }

  /**
   * Reads the first line of the passphrase resource.
   *
   * @throws IOException if the resource is missing or empty
   */
  private char[] readPassphrase() throws IOException {
    InputStream pwis = getClass().getResourceAsStream(PASSPHRASE_RESOURCE);
    if (TRACE) out("-- pwis: " + pwis);
    if (pwis == null) {
      throw new IOException("Passphrase resource not found on classpath: " + PASSPHRASE_RESOURCE);
    }
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(pwis))) {
      String line = reader.readLine();
      if (line == null) {
        throw new IOException("Passphrase resource is empty: " + PASSPHRASE_RESOURCE);
      }
      char[] passphrase = line.toCharArray();
      if (TRACE) out("passphrase is array of len: " + passphrase.length);
      return passphrase;
    }
  }

  /**
   * Loads a JKS store from a classpath resource.
   *
   * @param what label used in trace and error messages ("keystore" or "truststore")
   * @param resource classpath resource name of the store
   * @param passphrase password used to check the store's integrity
   * @throws IOException if the resource is missing or cannot be read
   */
  private KeyStore loadKeyStore(String what, String resource, char[] passphrase)
      throws GeneralSecurityException, IOException {
    if (TRACE) out("Create " + what + " ...");
    KeyStore store = KeyStore.getInstance("JKS");
    if (TRACE) out("Load " + what + " ...");
    try (InputStream in = getClass().getResourceAsStream(resource)) {
      if (TRACE) out("*store: " + resource + " stream: " + in);
      if (in == null) {
        // KeyStore.load(null, ...) would silently initialize an empty store instead.
        throw new IOException(what + " resource not found on classpath: " + resource);
      }
      store.load(in, passphrase);
    }
    if (TRACE) out("Done load " + what + " " + store);
    return store;
  }

  /**
   * Returns an SSL server socket for {@code port}.
   *
   * <p>The previously created socket is returned again if it is still open and {@code port} is
   * 0 or the port that socket is bound to. Otherwise a new socket is created and remembered as
   * the current one.
   *
   * @param port the port to listen on, or 0 for an automatically allocated port
   * @return the (possibly reused) server socket
   * @throws IOException if the socket cannot be created
   */
  @Override
  public ServerSocket createServerSocket(int port) throws IOException {
    if (canReuse(port)) {
      if (TRACE) out("reuse ServerSocket " + ss);
      return ss;
    }
    if (TRACE) {
      out("Creating ServerSocket for port " + port + " using factory " + ssf + " ...");
      new Exception("  --- called from --- ").printStackTrace();
    }
    ss = ssf.createServerSocket(port);
    if (TRACE) out("Done creating ServerSocket " + ss);
    return ss;
  }

  /** True if the cached socket is open and satisfies a request for {@code port}. */
  private boolean canReuse(int port) {
    return ss != null && !ss.isClosed() && (port == 0 || ss.getLocalPort() == port);
  }

  /** Hash is per class, consistent with {@link #equals(Object)}. */
  @Override
  public int hashCode() {
    return getClass().hashCode();
  }

  /**
   * Any two instances of exactly the same class are equal. This holds regardless of whether
   * they were built with the default or the custom SSL configuration.
   */
  @Override
  public boolean equals(Object obj) {
    return obj == this || (obj != null && getClass() == obj.getClass());
  }
}

/**
 * Resource names and demo passwords for the server and client key/trust stores.
 *
 * <p>Store names are built as {@code storeDir + role + "_keystore"} (or {@code "_truststore"}),
 * where the role is {@link #serverStr} or {@link #clientStr}. Passwords are built as
 * {@code role + "_" + sec}, which is {@code "<role>_demo_pswd"} with the initial {@link #sec}.
 * All fields are read at call time, so reassigning them changes subsequent results.
 */
class KeytoolAttrs {
  static String serverStr = "server";
  static String clientStr = "client";
  // Password suffix used for the demo default SSL setup.
  static char[] sec = "demo_pswd".toCharArray();
  // Prefix (with trailing separator) prepended to every store name.
  static String storeDir = "/tmp/"; // TODO: convert to System property

  static String serverKeyStore() {
    return storeName(serverStr, "_keystore");
  }

  static String serverTrustStore() {
    return storeName(serverStr, "_truststore");
  }

  static String serverPW() {
    return demoPassword(serverStr);
  }

  static String clientKeyStore() {
    return storeName(clientStr, "_keystore");
  }

  static String clientTrustStore() {
    return storeName(clientStr, "_truststore");
  }

  static String clientPW() {
    return demoPassword(clientStr);
  }

  /** Joins the store directory, the role and the store-kind suffix. */
  private static String storeName(String role, String suffix) {
    return storeDir + role + suffix;
  }

  /** Builds the demo password for a role from the shared secret suffix. */
  private static String demoPassword(String role) {
    return role + "_" + String.valueOf(sec);
  }
}






