csclient.cxx - systemtap
Data types defined
Functions defined
Macros defined
Source code
#include "config.h"
#if HAVE_NSS
#include "session.h"
#include "cscommon.h"
#include "csclient.h"
#include "util.h"
#include "stap-probe.h"
#include <sys/times.h>
#include <vector>
#include <fstream>
#include <sstream>
#include <cassert>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
extern "C" {
#include <unistd.h>
#include <linux/limits.h>
#include <sys/time.h>
#include <glob.h>
#include <limits.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <net/if.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <pwd.h>
}
#if HAVE_AVAHI
extern "C" {
#include <avahi-client/client.h>
#include <avahi-client/lookup.h>
#include <avahi-common/simple-watch.h>
#include <avahi-common/malloc.h>
#include <avahi-common/error.h>
#include <avahi-common/timeval.h>
}
#endif
extern "C" {
#include <ssl.h>
#include <nspr.h>
#include <nss.h>
#include <certdb.h>
#include <pk11pub.h>
#include <prerror.h>
#include <secerr.h>
#include <sslerr.h>
}
#include "nsscommon.h"
using namespace std;
// Message templates used by the connection code below.  Each is wrapped in
// _() (presumably the gettext translation marker -- confirm against the
// project's headers).
// Printed by badCertHandler before listing the certificate's DNS names.
#define STAP_CSC_01 _("WARNING: The domain name, %s, does not match the DNS name(s) on the server certificate:\n")
// Printed by handle_connection when the input file is missing or not a regular file.
#define STAP_CSC_02 _("could not find input file %s\n")
// Printed by handle_connection when the input file cannot be opened.
#define STAP_CSC_03 _("could not open input file %s\n")
// Printed by handle_connection when the output file cannot be opened.
#define STAP_CSC_04 _("Unable to open output file %s\n")
// Printed by handle_connection when writing the output file fails.
#define STAP_CSC_05 _("could not write to %s\n")
// File name of the MOK public certificate that process_response looks for
// in the unpacked server response.
#define MOK_PUBLIC_CERT_NAME "signing_key.x509"
// Forward declarations of the address helpers and PRNetAddr comparison
// operators.  Their definitions are not in this part of the file.
static PRIPv6Addr &copyAddress (PRIPv6Addr &PRin6, const in6_addr &in6);
static PRNetAddr &copyNetAddr (PRNetAddr &x, const PRNetAddr &y);
bool operator!= (const PRNetAddr &x, const PRNetAddr &y);
bool operator== (const PRNetAddr &x, const PRNetAddr &y);
// Error-reporting hook with C linkage: writes MSG followed by a newline to
// clog and flushes it.  The LOGIT argument is accepted but ignored.
extern "C"
void
nsscommon_error (const char *msg, int logit __attribute ((unused)))
{
  clog << msg << endl;
  clog.flush ();
}
// Describes one compile server: how to reach it (host name, address, port)
// and what is known about it (version, sysinfo, certificate info, MOK
// fingerprints).  Any of the fields may be unset.
struct compile_server_info
{
// Start with port 0, not fully specified, and an all-zero address
// (family 0, which hasAddress() treats as "no address").
compile_server_info () : port(0), fully_specified(false)
{
memset (& address, 0, sizeof (address));
}
string host_name;
PRNetAddr address;
unsigned short port;
bool fully_specified;
string version;
string sysinfo;
string certinfo;
vector<string> mok_fingerprints;
// True when there is no host name, no address and no certificate info.
bool empty () const
{
return this->host_name.empty () && ! this->hasAddress () && certinfo.empty ();
}
// True when the address family has been set to something other than 0.
bool hasAddress () const
{
return this->address.raw.family != 0;
}
// Store PORT, converted with htons, into the inet or ipv6 member of
// ADDRESS, whichever matches the address family, and return the stored
// value.  Asserts if the family is neither PR_AF_INET nor PR_AF_INET6.
// Note that the parameter shadows the PORT member, which is not modified.
unsigned short setAddressPort (unsigned short port)
{
if (this->address.raw.family == PR_AF_INET)
return this->address.inet.port = htons (port);
if (this->address.raw.family == PR_AF_INET6)
return this->address.ipv6.port = htons (port);
assert (0);
return 0;
}
// True when both an address and a non-zero port are known.
bool isComplete () const
{
return this->hasAddress () && port != 0;
}
// Partial-match equality: a field takes part in the comparison only when
// it is set on BOTH sides, so an unset field matches anything.
bool operator== (const compile_server_info &that) const
{
// If either side carries no address, version, sysinfo or certinfo,
// the host names must be identical.
if ((! this->hasAddress() && this->version.empty () &&
this->sysinfo.empty () && this->certinfo.empty ()) ||
(! that.hasAddress() && that.version.empty () &&
that.sysinfo.empty () && that.certinfo.empty ()))
{
if (this->host_name != that.host_name)
return false;
}
// Each remaining field rejects the match only if set on both sides
// and different.
if (this->hasAddress() && that.hasAddress() &&
this->address != that.address)
return false;
if (this->port != 0 && that.port != 0 &&
this->port != that.port)
return false;
if (! this->version.empty () && ! that.version.empty () &&
this->version != that.version)
return false;
if (! this->sysinfo.empty () && ! that.sysinfo.empty () &&
this->sysinfo != that.sysinfo)
return false;
if (! this->certinfo.empty () && ! that.certinfo.empty () &&
this->certinfo != that.certinfo)
return false;
if (! this->mok_fingerprints.empty () && ! that.mok_fingerprints.empty ()
&& this->mok_fingerprints != that.mok_fingerprints)
return false;
return true; }
// Ordering used by preferred_order: only the protocol versions are
// compared, and the operands are swapped, so a server with a higher
// version sorts before one with a lower version.
bool operator< (const compile_server_info &that) const
{
cs_protocol_version this_version (this->version.c_str ());
cs_protocol_version that_version (that.version.c_str ());
return that_version < this_version;
}
};
// Stream output for a single server description and for a list of them.
// The definitions are not in this part of the file.
ostream &operator<< (ostream &s, const compile_server_info &i);
ostream &operator<< (ostream &s, const vector<compile_server_info> &v);
// Sort SERVERS in place using compile_server_info::operator<.
static void
preferred_order (vector<compile_server_info> &servers)
{
  // A list of zero or one entries is left untouched.
  if (servers.size () >= 2)
    sort (servers.begin (), servers.end ());
}
// Pairs a host name with one network address.
struct resolved_host
{
  string host_name;
  PRNetAddr address;

  // Copies both arguments into the corresponding members.
  resolved_host (string chost_name, PRNetAddr caddress)
    : host_name (chost_name),
      address (caddress)
  {
  }
};
// Holds one list of server descriptions per category, plus a map from a
// string key to the resolved_host entries recorded for it.
// NOTE(review): the map key is presumably a host name -- confirm at the
// point where resolved_hosts is filled in.
struct compile_server_cache
{
vector<compile_server_info> default_servers;
vector<compile_server_info> specified_servers;
vector<compile_server_info> trusted_servers;
vector<compile_server_info> signing_servers;
vector<compile_server_info> online_servers;
vector<compile_server_info> all_servers;
map<string,vector<resolved_host> > resolved_hosts;
};
// Server property flags.  Each value is a distinct bit, so they can be
// OR-ed together (get_server_info takes an int "pmask" argument).
enum compile_server_properties {
compile_server_all = 0x1,
compile_server_trusted = 0x2,
compile_server_online = 0x4,
compile_server_compatible = 0x8,
compile_server_signer = 0x10,
compile_server_specified = 0x20
};
// Forward declarations of the server-information helpers.  Their
// definitions are not in this part of the file.
static compile_server_cache* cscache(systemtap_session& s);
static void query_server_status (systemtap_session &s, const string &status_string);
static void get_server_info (systemtap_session &s, int pmask, vector<compile_server_info> &servers);
static void get_all_server_info (systemtap_session &s, vector<compile_server_info> &servers);
static void get_default_server_info (systemtap_session &s, vector<compile_server_info> &servers);
static void get_specified_server_info (systemtap_session &s, vector<compile_server_info> &servers, bool no_default = false);
static void get_or_keep_online_server_info (systemtap_session &s, vector<compile_server_info> &servers, bool keep);
static void get_or_keep_trusted_server_info (systemtap_session &s, vector<compile_server_info> &servers, bool keep);
static void get_or_keep_signing_server_info (systemtap_session &s, vector<compile_server_info> &servers, bool keep);
static void get_or_keep_compatible_server_info (systemtap_session &s, vector<compile_server_info> &servers, bool keep);
static void keep_common_server_info (const compile_server_info &info_to_keep, vector<compile_server_info> &filtered_info);
static void keep_common_server_info (const vector<compile_server_info> &info_to_keep, vector<compile_server_info> &filtered_info);
static void keep_server_info_with_cert_and_port (systemtap_session &s, const compile_server_info &server, vector<compile_server_info> &servers);
static void add_server_info (const compile_server_info &info, vector<compile_server_info>& list);
static void add_server_info (const vector<compile_server_info> &source, vector<compile_server_info> &target);
static void merge_server_info (const compile_server_info &source, compile_server_info &target);
static void resolve_host (systemtap_session& s, compile_server_info &server, vector<compile_server_info> &servers);

// Result codes returned by client_connect and compile_using_server.
#define SUCCESS 0
#define GENERAL_ERROR 1
#define CA_CERT_INVALID_ERROR 2
#define SERVER_CERT_EXPIRED_ERROR 3

// Forward declarations of the certificate-trust helpers.
static void add_server_trust (systemtap_session &s, const string &cert_db_path, vector<compile_server_info> &server_list);
static void revoke_server_trust (systemtap_session &s, const string &cert_db_path, const vector<compile_server_info> &server_list);
static void get_server_info_from_db (systemtap_session &s, vector<compile_server_info> &servers, const string &cert_db_path);
// Path of the client certificate database under SYSCONFDIR.
static string
global_client_cert_db_path ()
{
  return string (SYSCONFDIR "/systemtap/ssl/client");
}
// SSL certificate database used as the "private" one: whatever
// local_client_cert_db_path() reports.
static string
private_ssl_cert_db_path ()
{
  string db_path = local_client_cert_db_path ();
  return db_path;
}
// SSL certificate database used as the "global" one: the same path as
// global_client_cert_db_path().
static string
global_ssl_cert_db_path ()
{
  string db_path = global_client_cert_db_path ();
  return db_path;
}
// Path of the certificate database that trust_already_in_place reports as
// the module-signer database.
static string
signing_cert_db_path ()
{
  return string (SYSCONFDIR "/systemtap/staprun");
}
// Parameters for one connection attempt.  Filled in by client_connect and
// passed through do_connect, setupSSLSocket, handle_connection and (as the
// callback argument) badCertHandler.
typedef struct connectionState_t
{
// Host name handed to SSL_SetURL.
const char *hostName;
// Address to connect to; its family also selects the socket type.
PRNetAddr addr;
// File sent to the server.  When this or outfileName is NULL, only a
// zero length word is sent.
const char *infileName;
// File that receives the server's response.
const char *outfileName;
// NULL, "session" or "permanent"; consulted by badCertHandler when the
// error is SEC_ERROR_CA_CERT_INVALID.
const char *trustNewServerMode;
} connectionState_t;
#if 0
#endif
// Import SERVERCERT into the internal key slot under the nickname given by
// server_cert_nickname() and set its trust flags from the string "P,P,P".
// Returns SECSuccess only if every step succeeded.
static SECStatus
trustNewServer (CERTCertificate *serverCert)
{
  SECStatus secStatus = SECFailure;
  CERTCertTrust *trust = NULL;
  PK11SlotInfo *slot = PK11_GetInternalKeySlot();
  const char *nickname = server_cert_nickname ();

  // Without a slot there is nowhere to import the certificate; fail
  // instead of handing a NULL slot to PK11_ImportCert.
  if (! slot)
    goto done;

  secStatus = PK11_ImportCert(slot, serverCert, CK_INVALID_HANDLE, nickname, PR_FALSE);
  if (secStatus != SECSuccess)
    goto done;

  trust = (CERTCertTrust *)PORT_ZAlloc(sizeof(CERTCertTrust));
  if (! trust)
    {
      secStatus = SECFailure;
      goto done;
    }

  secStatus = CERT_DecodeTrustString(trust, "P,P,P");
  if (secStatus != SECSuccess)
    goto done;

  secStatus = CERT_ChangeCertTrust(CERT_GetDefaultCertDB(), serverCert, trust);

 done:
  if (slot)
    PK11_FreeSlot (slot);
  if (trust)
    PORT_Free(trust);
  return secStatus;
}
// Bad-certificate callback installed by setupSSLSocket.  ARG is the
// connectionState_t of the connection being made.
//
// - SSL_ERROR_BAD_CERT_DOMAIN: accepted (SECSuccess).  When the expected
//   host name is known, a warning is printed followed by the DNS names
//   found in the certificate's subject-alt-name extension, if any.
// - SEC_ERROR_CA_CERT_INVALID: rejected unless trustNewServerMode is
//   "session" (accepted) or "permanent" (result of trustNewServer).
// - Anything else: rejected (SECFailure).
static SECStatus
badCertHandler(void *arg, PRFileDesc *sslSocket)
{
  SECStatus secStatus = SECFailure;
  PRErrorCode errorNumber;
  CERTCertificate *serverCert = NULL;
  SECItem subAltName;
  PRArenaPool *tmpArena = NULL;
  CERTGeneralName *nameList, *current;
  char *expected = NULL;
  const connectionState_t *connectionState = (connectionState_t *)arg;

  errorNumber = PR_GetError ();
  switch (errorNumber)
    {
    case SSL_ERROR_BAD_CERT_DOMAIN:
      // A name mismatch is only warned about; every exit from this case
      // leaves secStatus at SECSuccess.
      secStatus = SECSuccess;
      expected = SSL_RevealURL (sslSocket);
      if (expected == NULL || *expected == '\0')
        break;

      fprintf (stderr, STAP_CSC_01, expected);

      // The peer certificate may be unavailable; treat that like a
      // missing extension rather than passing NULL on.
      subAltName.data = NULL;
      serverCert = SSL_PeerCertificate (sslSocket);
      if (serverCert == NULL
          || CERT_FindCertExtension (serverCert,
                                     SEC_OID_X509_SUBJECT_ALT_NAME,
                                     & subAltName) != SECSuccess
          || ! subAltName.data)
        {
          fprintf (stderr, _("Unable to find alt name extension on the server certificate\n"));
          break;
        }

      tmpArena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
      if (! tmpArena)
        {
          fprintf (stderr, _("Out of memory\n"));
          SECITEM_FreeItem(& subAltName, PR_FALSE);
          break;
        }

      // The decoded list lives in tmpArena, so the raw extension can be
      // released right away.
      nameList = CERT_DecodeAltNameExtension (tmpArena, & subAltName);
      SECITEM_FreeItem(& subAltName, PR_FALSE);
      if (! nameList)
        {
          fprintf (stderr, _("Unable to decode alt name extension on server certificate\n"));
          break;
        }

      // Walk the list until it comes back around to its first entry.
      current = nameList;
      do
        {
          if (current->type == certDNSName)
            {
              fprintf (stderr, " %.*s\n",
                       (int)current->name.other.len, current->name.other.data);
            }
          current = CERT_GetNextGeneralName (current);
        }
      while (current != nameList);
      break;

    case SEC_ERROR_CA_CERT_INVALID:
      secStatus = SECFailure;
      if (! connectionState->trustNewServerMode)
        break;
      if (strcmp (connectionState->trustNewServerMode, "session") == 0)
        {
          secStatus = SECSuccess;
          break;
        }
      if (strcmp (connectionState->trustNewServerMode, "permanent") == 0)
        {
          serverCert = SSL_PeerCertificate (sslSocket);
          if (serverCert != NULL)
            {
              secStatus = trustNewServer (serverCert);
            }
        }
      break;

    default:
      secStatus = SECFailure;
      break;
    }

  if (expected)
    PORT_Free (expected);
  if (tmpArena)
    PORT_FreeArena (tmpArena, PR_FALSE);
  if (serverCert != NULL)
    {
      CERT_DestroyCertificate (serverCert);
    }
  return secStatus;
}
// Create a blocking TCP socket for the address family in CONNECTIONSTATE,
// wrap it in an SSL socket configured as a client, and install
// badCertHandler with CONNECTIONSTATE as its argument.
// Returns the SSL socket, or NULL after closing the TCP socket on failure.
static PRFileDesc *
setupSSLSocket (connectionState_t *connectionState)
{
  PRFileDesc *tcpSocket;
  PRFileDesc *sslSocket;
  PRSocketOptionData socketOption;
  PRStatus prStatus;
  SECStatus secStatus;

  tcpSocket = PR_OpenTCPSocket(connectionState->addr.raw.family);
  if (tcpSocket == NULL)
    goto loser;

  // Make the socket blocking.
  socketOption.option = PR_SockOpt_Nonblocking;
  socketOption.value.non_blocking = PR_FALSE;
  prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
  if (prStatus != PR_SUCCESS)
    goto loser;

  sslSocket = SSL_ImportFD(NULL, tcpSocket);
  if (!sslSocket)
    goto loser;

  secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE);
  if (secStatus != SECSuccess)
    goto loser;

  secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
  if (secStatus != SECSuccess)
    goto loser;

  secStatus = SSL_BadCertHook(sslSocket, (SSLBadCertHandler)badCertHandler,
                              connectionState);
  if (secStatus != SECSuccess)
    goto loser;

  return sslSocket;

 loser:
  if (tcpSocket)
    PR_Close(tcpSocket);
  return NULL;
}
// Exchange data with the server over the connected SSLSOCKET.
//
// When either file name in CONNECTIONSTATE is NULL, only a 32-bit zero is
// sent.  Otherwise the size of the input file is sent as a 32-bit value in
// network byte order, followed by the file itself, and everything the
// server sends back until end of stream is written to the output file.
// Returns SECSuccess, or SECFailure on any error.
static SECStatus
handle_connection (PRFileDesc *sslSocket, connectionState_t *connectionState)
{
  PRInt32 numBytes;
  char *readBuffer;
  PRFileInfo info;
  PRFileDesc *local_file_fd;
  PRStatus prStatus;
  SECStatus secStatus = SECSuccess;

#define READ_BUFFER_SIZE (60 * 1024)

  // Nothing to transfer: just tell the server that zero bytes follow.
  if (! connectionState->infileName || ! connectionState->outfileName)
    {
      numBytes = htonl ((PRInt32)0);
      numBytes = PR_Write (sslSocket, & numBytes, sizeof (numBytes));
      if (numBytes < 0)
        return SECFailure;
      return SECSuccess;
    }

  // The input must be an existing regular file.
  prStatus = PR_GetFileInfo(connectionState->infileName, &info);
  if (prStatus != PR_SUCCESS ||
      info.type != PR_FILE_FILE ||
      info.size < 0)
    {
      fprintf (stderr, STAP_CSC_02,
               connectionState->infileName);
      return SECFailure;
    }

  local_file_fd = PR_Open(connectionState->infileName, PR_RDONLY, 0);
  if (local_file_fd == NULL)
    {
      fprintf (stderr, STAP_CSC_03, connectionState->infileName);
      return SECFailure;
    }

  // Send the file size, then the file contents.
  numBytes = htonl ((PRInt32)info.size);
  numBytes = PR_Write(sslSocket, & numBytes, sizeof (numBytes));
  if (numBytes < 0)
    {
      PR_Close(local_file_fd);
      return SECFailure;
    }

  numBytes = PR_TransmitFile(sslSocket, local_file_fd,
                             NULL, 0,
                             PR_TRANSMITFILE_KEEP_OPEN,
                             PR_INTERVAL_NO_TIMEOUT);
  // The input file is no longer needed whether or not the send worked.
  PR_Close(local_file_fd);
  if (numBytes < 0)
    return SECFailure;

  readBuffer = (char *)PORT_Alloc(READ_BUFFER_SIZE);
  if (! readBuffer) {
    fprintf (stderr, _("Out of memory\n"));
    return SECFailure;
  }

  local_file_fd = PR_Open(connectionState->outfileName, PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE,
                          PR_IRUSR | PR_IWUSR | PR_IRGRP | PR_IWGRP | PR_IROTH);
  if (local_file_fd == NULL)
    {
      fprintf (stderr, STAP_CSC_04, connectionState->outfileName);
      // Do not leak the read buffer on this early exit.
      PORT_Free(readBuffer);
      return SECFailure;
    }

  // Copy the response to the output file until the server closes the stream.
  while (PR_TRUE)
    {
      numBytes = PR_Read (sslSocket, readBuffer, READ_BUFFER_SIZE);
      if (numBytes == 0)
        break;
      if (numBytes < 0)
        {
          secStatus = SECFailure;
          break;
        }

      numBytes = PR_Write(local_file_fd, readBuffer, numBytes);
      if (numBytes < 0)
        {
          fprintf (stderr, STAP_CSC_05, connectionState->outfileName);
          secStatus = SECFailure;
          break;
        }
    }

  // Release with the PORT_ allocator family that allocated the buffer.
  PORT_Free(readBuffer);
  PR_Close(local_file_fd);
  return secStatus;
}
// Make one connection attempt: set up the SSL socket, connect to the
// address in CONNECTIONSTATE, force the SSL handshake and run
// handle_connection.  The socket is always closed before returning.
// Returns SECSuccess only if every step succeeded.
static SECStatus
do_connect (connectionState_t *connectionState)
{
  PRFileDesc *sslSocket;
  PRStatus prStatus;
  SECStatus secStatus;

  secStatus = SECSuccess;

  sslSocket = setupSSLSocket (connectionState);
  if (sslSocket == NULL)
    return SECFailure;

  // Record the host name to which the connection is being made.
  secStatus = SSL_SetURL(sslSocket, connectionState->hostName);
  if (secStatus != SECSuccess)
    goto done;

  prStatus = PR_Connect(sslSocket, & connectionState->addr, PR_INTERVAL_NO_TIMEOUT);
  if (prStatus != PR_SUCCESS)
    {
      secStatus = SECFailure;
      goto done;
    }

  // Reset the handshake state (PR_FALSE: act as client) and run it now so
  // that certificate problems surface before any data is exchanged.
  secStatus = SSL_ResetHandshake(sslSocket, PR_FALSE);
  if (secStatus != SECSuccess)
    goto done;

  secStatus = SSL_ForceHandshake(sslSocket);
  if (secStatus != SECSuccess)
    goto done;

  secStatus = handle_connection(sslSocket, connectionState);

 done:
  // The result of closing the socket does not affect the reported status.
  PR_Close(sslSocket);
  return secStatus;
}
// True when ADDRESS is an IPv6 address whose first two bytes are 0xfe 0x80.
static bool
isIPv6LinkLocal (const PRNetAddr &address)
{
  if (address.raw.family != PR_AF_INET6)
    return false;

  return address.ipv6.ip.pr_s6_addr[0] == 0xfe
         && address.ipv6.ip.pr_s6_addr[1] == 0x80;
}
int
client_connect (const compile_server_info &server,
const char* infileName, const char* outfileName,
const char* trustNewServer)
{
SECStatus secStatus;
PRErrorCode errorNumber;
int attempt;
int errCode = GENERAL_ERROR;
struct connectionState_t connectionState;
memset (& connectionState, 0, sizeof (connectionState));
connectionState.hostName = server.host_name.c_str ();
connectionState.addr = server.address;
connectionState.infileName = infileName;
connectionState.outfileName = outfileName;
connectionState.trustNewServerMode = trustNewServer;
for (attempt = 0; attempt < 5; ++attempt)
{
secStatus = do_connect (& connectionState);
if (secStatus == SECSuccess)
return SUCCESS;
errorNumber = PR_GetError ();
switch (errorNumber)
{
case PR_CONNECT_RESET_ERROR:
sleep (1);
break; case SEC_ERROR_EXPIRED_CERTIFICATE:
errCode = SERVER_CERT_EXPIRED_ERROR;
return errCode;
case SEC_ERROR_CA_CERT_INVALID:
errCode = CA_CERT_INVALID_ERROR;
return errCode;
default:
return errCode;
}
}
return errCode;
}
// Run the whole client side of a compile-server request: initialize,
// create and package the request, find a server and connect to it, then
// unpack and process the response.  The first step that returns non-zero
// ends the sequence.  Returns that step's result, or 0 on success.
int
compile_server_client::passes_0_4 ()
{
PROBE1(stap, client__start, &s);
if (s.verbose || ! s.auto_server_msgs.empty ())
clog << _("Using a compile server.") << endl;
// Sample CPU and wall-clock time so the elapsed time can be reported.
struct tms tms_before;
times (& tms_before);
struct timeval tv_before;
gettimeofday (&tv_before, NULL);
int rc = initialize ();
assert_no_interrupts();
if (rc != 0) goto done;
rc = create_request ();
assert_no_interrupts();
if (rc != 0) goto done;
rc = package_request ();
assert_no_interrupts();
if (rc != 0) goto done;
rc = find_and_connect_to_server ();
assert_no_interrupts();
if (rc != 0) goto done;
rc = unpack_response ();
assert_no_interrupts();
if (rc != 0) goto done;
rc = process_response ();
done:
struct tms tms_after;
times (& tms_after);
unsigned _sc_clk_tck = sysconf (_SC_CLK_TCK);
struct timeval tv_after;
gettimeofday (&tv_after, NULL);
// TIMESPRINT streams the user, system and real time spent since the
// "before" samples, each converted to milliseconds.
#define TIMESPRINT "in " << \
(tms_after.tms_cutime + tms_after.tms_utime \
- tms_before.tms_cutime - tms_before.tms_utime) * 1000 / (_sc_clk_tck) << "usr/" \
<< (tms_after.tms_cstime + tms_after.tms_stime \
- tms_before.tms_cstime - tms_before.tms_stime) * 1000 / (_sc_clk_tck) << "sys/" \
<< ((tv_after.tv_sec - tv_before.tv_sec) * 1000 + \
((long)tv_after.tv_usec - (long)tv_before.tv_usec) / 1000) << "real ms."
if (rc == 0)
{
// Stopping at pass 4 implies that the module is to be saved.
if (s.last_pass == 4)
s.save_module = true;
if (! pending_interrupts)
{
if (s.save_module)
{
// Copy the module from the temporary directory to a file named
// s.module_filename(), along with its ".sgn" companion if one exists.
string module_src_path = s.tmpdir + "/" + s.module_filename();
string module_dest_path = s.module_filename();
copy_file (module_src_path, module_dest_path, s.verbose >= 3);
module_src_path += ".sgn";
if (file_exists (module_src_path))
{
module_dest_path += ".sgn";
copy_file(module_src_path, module_dest_path, s.verbose >= 3);
}
}
// When stopping at pass 4, print the module file name on stdout.
if (s.last_pass == 4)
{
cout << s.module_filename() << endl;
}
}
}
if (s.verbose)
{
// Report the server that was used ("?" if none was recorded),
// memory usage and timings.
string ws = s.winning_server;
if (ws == "") ws = "?";
clog << _("Passes: via server ") << ws << " "
<< getmemusage()
<< TIMESPRINT
<< endl;
}
if (rc && !s.dump_mode)
{
clog << _("Passes: via server failed. Try again with another '-v' option.") << endl;
}
PROBE1(stap, client__end, &s);
return rc;
}
int
compile_server_client::initialize ()
{
int rc = 0;
argc = 0;
private_ssl_dbs.push_back (private_ssl_cert_db_path ());
public_ssl_dbs.push_back (global_ssl_cert_db_path ());
client_tmpdir = s.tmpdir + "/client";
rc = create_dir (client_tmpdir.c_str ());
if (rc != 0)
{
const char* e = strerror (errno);
clog << _("ERROR: cannot create temporary directory (\"")
<< client_tmpdir << "\"): " << e
<< endl;
}
return rc;
}
// Populate the client directory with everything that makes up a request:
// the protocol version, the script (a copy of stdin when the script is
// "-"), the -I include paths, the remaining arguments, the sysinfo line,
// the localization variables and, if any, the MOK fingerprints.
// Returns 0 on success, or the non-zero result of the first step that failed.
int
compile_server_client::create_request ()
{
  int rc = write_to_file (client_tmpdir + "/version", CURRENT_CS_PROTOCOL_VERSION);
  if (rc != 0)
    return rc;

  if (s.script_file != "")
    {
      if (s.script_file == "-")
        {
          // The script comes from stdin: capture it into "script/-".
          string packaged_script_dir = client_tmpdir + "/script";
          rc = create_dir (packaged_script_dir.c_str ());
          if (rc != 0)
            {
              const char* e = strerror (errno);
              clog << _("ERROR: cannot create temporary directory ")
                   << packaged_script_dir << ": " << e
                   << endl;
              return rc;
            }
          rc = ! copy_file("/dev/stdin", packaged_script_dir + "/-");
          if (rc != 0)
            return rc;

          rc = add_package_arg ("script/-");
          if (rc != 0)
            return rc;
        }
      else
        {
          rc = include_file_or_directory ("script", s.script_file);
          if (rc != 0)
            return rc;
        }
    }

  // Include paths from s.include_arg_start onward are packaged, each
  // preceded by a "-I" argument.
  if (s.include_arg_start != -1)
    {
      unsigned limit = s.include_path.size ();
      for (unsigned i = s.include_arg_start; i < limit; ++i)
        {
          rc = add_package_arg ("-I");
          if (rc != 0)
            return rc;
          rc = include_file_or_directory ("tapset", s.include_path[i]);
          if (rc != 0)
            return rc;
        }
    }

  rc = add_package_args ();
  if (rc != 0)
    return rc;

  string sysinfo = "sysinfo: " + s.kernel_release + " " + s.architecture;
  rc = write_to_file (client_tmpdir + "/sysinfo", sysinfo);
  if (rc != 0)
    return rc;

  // Check this result before going on: the fingerprint step below would
  // otherwise overwrite it and hide a failure.
  rc = add_localization_variables();
  if (rc != 0)
    return rc;

  if (! s.mok_fingerprints.empty())
    {
      // One fingerprint per line.
      ostringstream fingerprints;
      vector<string>::const_iterator it;
      for (it = s.mok_fingerprints.begin(); it != s.mok_fingerprints.end();
           it++)
        fingerprints << *it << endl;

      rc = write_to_file(client_tmpdir + "/mok_fingerprints",
                         fingerprints.str());
      if (rc != 0)
        return rc;
    }

  return rc;
}
// Package every entry of s.server_args and then, if s.args is not empty,
// a "--" separator followed by every entry of s.args.
// Returns 0, or the first non-zero result from add_package_arg.
int
compile_server_client::add_package_args ()
{
  const unsigned server_arg_count = s.server_args.size();
  for (unsigned n = 0; n < server_arg_count; ++n)
    {
      const int status = add_package_arg (s.server_args[n]);
      if (status != 0)
        return status;
    }

  const unsigned script_arg_count = s.args.size();
  if (script_arg_count == 0)
    return 0;

  int status = add_package_arg ("--");
  for (unsigned n = 0; status == 0 && n < script_arg_count; ++n)
    status = add_package_arg (s.args[n]);

  return status;
}
// Store ARG as the next packaged argument: it is written, with nothing
// appended, to the file "argvN" in the client directory, where N is the
// incremented argument counter.
// Returns the result of write_to_file (0 on success) so that callers,
// which all check it, see a failed write.
int
compile_server_client::add_package_arg (const string &arg)
{
  ostringstream fname;
  fname << client_tmpdir << "/argv" << ++argc;
  return write_to_file (fname.str (), arg);
}
// Add the file or directory PATH to the request under SUBDIR and name it
// in the packaged arguments.  Returns 0 on success, non-zero on failure.
int
compile_server_client::include_file_or_directory (
const string &subdir, const string &path
)
{
vector<string> components;
string name;
int rc;
string rpath;
char *cpath = canonicalize_file_name (path.c_str ());
if (! cpath)
{
// PATH cannot be canonicalized: use it relative to the current working
// directory.  Nothing is created in the client directory in this case.
char cwd[PATH_MAX];
if (getcwd (cwd, sizeof (cwd)) == NULL)
{
rpath = path;
rc = 1;
goto done;
}
rpath = string (cwd) + "/" + path;
}
else
{
// The canonical name is copied, then the buffer is freed.
rpath = cpath;
free (cpath);
// The root directory itself is refused.
if (rpath == "/")
{
if (rpath != path)
clog << _F("%s resolves to %s\n", path.c_str (), rpath.c_str ());
clog << _F("Unable to send %s to the server\n", path.c_str ());
return 1;
}
// Create SUBDIR inside the client directory ...
name = client_tmpdir + "/" + subdir;
rc = create_dir (name.c_str ());
if (rc) goto done;
// ... then one directory per leading component of the canonical path.
assert (rpath[0] == '/');
tokenize (rpath.substr (1), components, "/");
assert (components.size () >= 1);
unsigned i;
for (i = 0; i < components.size() - 1; ++i)
{
// Empty components are skipped.  Only the "continue" belongs to the
// "if"; the append on the same line runs for every other component.
if (components[i].empty ())
continue; name += "/" + components[i];
rc = create_dir (name.c_str ());
if (rc) goto done;
}
// The last component becomes a symbolic link to the real file or directory.
assert (i == components.size () - 1);
name += "/" + components[i];
rc = symlink (rpath.c_str (), name.c_str ());
if (rc) goto done;
}
// Name the file or directory in the packaged arguments, dropping the
// first character of RPATH (asserted above to be '/' in the
// canonicalized case).
rc = add_package_arg (subdir + "/" + rpath.substr (1));
done:
if (rc != 0)
{
// NOTE(review): errno is read here for every failure, including ones
// that may not have set it (e.g. add_package_arg) -- confirm the
// message is meaningful in those cases.
const char* e = strerror (errno);
clog << "ERROR: unable to add "
<< rpath
<< " to temp directory as "
<< name << ": " << e
<< endl;
}
return rc;
}
int
compile_server_client::add_localization_variables()
{
int rc;
string envVar;
string fname;
const set<string> &locVars = localization_variables();
set<string>::iterator it;
for (it = locVars.begin(); it != locVars.end(); it++)
{
char* var = getenv((*it).c_str());
if (var)
envVar += *it + "=" + (string)var + "\n";
}
fname = client_tmpdir + "/locale";
rc = write_to_file(fname, envVar);
return rc;
}
int
compile_server_client::package_request ()
{
client_zipfile = client_tmpdir + ".zip";
string cmd = "cd " + cmdstr_quoted(client_tmpdir) + " && zip -qr "
+ cmdstr_quoted(client_zipfile) + " *";
vector<string> sh_cmd;
sh_cmd.push_back("sh");
sh_cmd.push_back("-c");
sh_cmd.push_back(cmd);
int rc = stap_system (s.verbose, sh_cmd);
return rc;
}
// Build the list of candidate servers from the specified servers, put it
// in preferred order and compile using it.
// Returns 0 on success and 1 when no suitable server was found or none
// could be used.
int
compile_server_client::find_and_connect_to_server ()
{
vector<compile_server_info> specified_servers;
get_specified_server_info (s, specified_servers);
vector<compile_server_info> server_list;
for (vector<compile_server_info>::const_iterator i = specified_servers.begin ();
i != specified_servers.end ();
++i)
{
// A fully specified server with an address is used as is.
if (i->hasAddress() && i->fully_specified)
add_server_info (*i, server_list);
else
{
// Otherwise start from the online servers and keep those matching
// this specification.
vector<compile_server_info> online_servers;
get_or_keep_online_server_info (s, online_servers, false);
if (! i->fully_specified)
{
// Not fully specified: also require compatibility and, when the
// privilege does not include pr_stapdev, a signing server.
get_or_keep_compatible_server_info (s, online_servers, true);
if (! pr_contains (s.privilege, pr_stapdev))
get_or_keep_signing_server_info (s, online_servers, true);
}
keep_common_server_info (*i, online_servers);
add_server_info (online_servers, server_list);
}
}
unsigned limit = server_list.size ();
if (limit == 0)
{
// No candidates: explain what is online and what was asked for.
clog << _("Unable to find a suitable compile server. [man stap-server]") << endl;
vector<compile_server_info> online_servers;
get_or_keep_online_server_info (s, online_servers, false);
if (online_servers.empty ())
clog << _("No servers online to select from.") << endl;
else
{
clog << _("The following servers are online:") << endl;
clog << online_servers;
if (! specified_servers.empty ())
{
clog << _("The following servers were requested:") << endl;
clog << specified_servers;
}
else
{
string criteria = "online,trusted,compatible";
if (! pr_contains (s.privilege, pr_stapdev))
criteria += ",signer";
clog << _F("No servers matched the selection criteria of %s.", criteria.c_str())
<< endl;
}
}
return 1;
}
preferred_order (server_list);
int rc = compile_using_server (server_list);
if (rc == SUCCESS)
return 0;
// An expired server certificate gets one more try after a 2 second pause.
if (rc == SERVER_CERT_EXPIRED_ERROR)
{
if (s.verbose >= 2)
clog << _("The server's certificate was expired. Trying again") << endl << flush;
sleep (2);
rc = compile_using_server (server_list);
if (rc == SUCCESS)
return 0; }
clog << _("Unable to connect to a server.") << endl;
// The list of tried servers is printed only at verbosity exactly 1.
if (s.verbose == 1)
{
clog << _("The following servers were tried:") << endl;
clog << server_list;
}
return 1; }
// Try each existing certificate database in turn (private ones first, then
// public ones) and, with each, try to connect to every server in SERVERS
// that has an address and a port.  The request zip file is sent and the
// response is stored in "<tmpdir>/server.zip".
//
// Stops at the first successful connection and records that server in
// s.winning_server.  Returns SUCCESS, SERVER_CERT_EXPIRED_ERROR if no
// server succeeded and at least one had an expired certificate, or
// otherwise the result of the last connection attempt (GENERAL_ERROR if
// none was made).
int
compile_server_client::compile_using_server (
  vector<compile_server_info> &servers
)
{
  s.NSPR_init ();

  // Preset the NSPR error code.  NOTE(review): presumably so that error
  // reporting is meaningful when no connection is attempted -- confirm.
  PR_SetError (SEC_ERROR_CA_CERT_INVALID, 0);

  // Private databases first, followed by the public ones.
  vector<string> dbs = private_ssl_dbs;
  vector<string>::iterator i = dbs.end();
  dbs.insert (i, public_ssl_dbs.begin (), public_ssl_dbs.end ());

  int rc = GENERAL_ERROR;
  bool serverCertExpired = false;
  for (i = dbs.begin (); i != dbs.end (); ++i)
    {
      // Skip databases that do not exist or cannot be initialized.
      if (! file_exists (*i))
        continue;

      const char *cert_dir = i->c_str ();
      SECStatus secStatus = nssInit (cert_dir);
      if (secStatus != SECSuccess)
        continue;

      // Allow every implemented cipher.
      for (const PRUint16 *cipher = SSL_ImplementedCiphers; *cipher != 0; ++cipher)
        SSL_CipherPolicySet(*cipher, SSL_ALLOWED);
      SSL_ClearSessionCache ();

      server_zipfile = s.tmpdir + "/server.zip";

      for (vector<compile_server_info>::iterator j = servers.begin ();
           j != servers.end ();
           ++j)
        {
          // A server without an address or a port cannot be contacted.
          if (! j->hasAddress() || j->port == 0)
            continue;
          j->setAddressPort (j->port);

          if (s.verbose >= 2)
            clog << _F("Attempting SSL connection with %s\n"
                       " using certificates from the database in %s\n",
                       lex_cast(*j).c_str(), cert_dir);

          rc = client_connect (*j, client_zipfile.c_str(), server_zipfile.c_str (),
                               NULL);
          if (rc == SUCCESS)
            {
              s.winning_server = lex_cast(*j);
              break;
            }

          // Remember an expired certificate, but keep looking.
          if (rc == SERVER_CERT_EXPIRED_ERROR)
            {
              serverCertExpired = true;
              continue;
            }

          if (s.verbose >= 2)
            {
              clog << _(" Unable to connect: ");
              nssError ();
              if (isIPv6LinkLocal (j->address) && j->address.ipv6.scope_id == 0)
                {
                  clog << _(" The address is an IPv6 link-local address with no scope specifier.")
                       << endl;
                }
            }
        }

      SSL_ClearSessionCache ();
      nssCleanup (cert_dir);

      // rc holds one of this file's result codes, so test it against SUCCESS.
      if (rc == SUCCESS)
        break;
    }

  if (rc != SUCCESS)
    {
      if (serverCertExpired)
        rc = SERVER_CERT_EXPIRED_ERROR;
    }

  return rc;
}
// Unzip the server response into "<tmpdir>/server", read the server's
// protocol version if present, and symlink every entry of the single
// "stap??????" directory found in the response into the session's
// temporary directory.  Finally some lines are filtered out of the
// response's stderr and stdout files with sed.
// Returns 0 on success and non-zero on failure.
int
compile_server_client::unpack_response ()
{
  server_tmpdir = s.tmpdir + "/server";

  vector<string> cmd;
  cmd.push_back("unzip");
  cmd.push_back("-qd");
  cmd.push_back(server_tmpdir);
  cmd.push_back(server_zipfile);
  int rc = stap_system (s.verbose, cmd);
  if (rc != 0)
    {
      clog << _F("Unable to unzip the server response '%s'\n", server_zipfile.c_str());
      return rc;
    }

  // The version file is optional.
  string filename = server_tmpdir + "/version";
  if (file_exists (filename))
    ::read_from_file (filename, server_version);
  show_server_compatibility ();

  glob_t globbuf;
  string filespec = server_tmpdir + "/stap??????";
  if (s.verbose >= 3)
    clog << _F("Searching \"%s\"\n", filespec.c_str());
  int r = glob(filespec.c_str (), 0, NULL, & globbuf);
  if (r != GLOB_NOSPACE && r != GLOB_ABORTED && r != GLOB_NOMATCH)
    {
      if (globbuf.gl_pathc > 1)
        {
          clog << _("Incorrect number of files in server response") << endl;
          rc = 1;
          goto done;
        }

      assert (globbuf.gl_pathc == 1);
      string dirname = globbuf.gl_pathv[0];
      if (s.verbose >= 3)
        clog << _(" found ") << dirname << endl;

      // The directory name has been copied; release the first result
      // before the buffer is reused, otherwise it would be leaked.
      globfree (& globbuf);

      filespec = dirname + "/*";
      if (s.verbose >= 3)
        clog << _F("Searching \"%s\"\n", filespec.c_str());
      r = glob(filespec.c_str (), GLOB_PERIOD, NULL, & globbuf);
      if (r != GLOB_NOSPACE && r != GLOB_ABORTED && r != GLOB_NOMATCH)
        {
          unsigned prefix_len = dirname.size () + 1;
          for (unsigned i = 0; i < globbuf.gl_pathc; ++i)
            {
              // GLOB_PERIOD also matches "." and ".."; skip them.
              string oldname = globbuf.gl_pathv[i];
              if (oldname.substr (oldname.size () - 2) == "/." ||
                  oldname.substr (oldname.size () - 3) == "/..")
                continue;

              string newname = s.tmpdir + "/" + oldname.substr (prefix_len);
              if (s.verbose >= 3)
                clog << _F(" found %s -- linking from %s", oldname.c_str(), newname.c_str());
              rc = symlink (oldname.c_str (), newname.c_str ());
              if (rc != 0)
                {
                  clog << _F("Unable to link '%s' to '%s':%s\n",
                             oldname.c_str(), newname.c_str(), strerror(errno));
                  goto done;
                }
            }
        }
    }

  // For servers older than protocol 1.6, remove the "Keeping temporary
  // directory" lines from the response's stderr.
  if (server_version < "1.6")
    {
      cmd.clear();
      cmd.push_back("sed");
      cmd.push_back("-i");
      cmd.push_back("/^Keeping temporary directory.*/ d");
      cmd.push_back(server_tmpdir + "/stderr");
      stap_system (s.verbose, cmd);
    }

  // Remove lines ending in ".ko" from the response's stdout.
  cmd.clear();
  cmd.push_back("sed");
  cmd.push_back("-i");
  cmd.push_back("/^.*\\.ko$/ d");
  cmd.push_back(server_tmpdir + "/stdout");
  stap_system (s.verbose, cmd);

 done:
  globfree (& globbuf);
  return rc;
}
// Interpret the unpacked server response: read the server's return code,
// locate the compiled module (when the last pass is 4 or later), copy the
// MOK public certificate if the server sent one, and replay the server's
// stderr and stdout.
// Returns the return code read from the response, 1 if a module was
// expected but not returned, or non-zero if the "rc" file cannot be read.
int
compile_server_client::process_response ()
{
string filename = server_tmpdir + "/rc";
int stap_rc;
int rc = read_from_file (filename, stap_rc);
if (rc != 0)
return rc;
// From here on rc carries the return code reported by the server.
rc = stap_rc;
if (s.last_pass >= 4)
{
// Look for the module among the files linked into the temporary directory.
string filespec = s.tmpdir + "/*.ko";
if (s.verbose >= 3)
clog << _F("Searching \"%s\"\n", filespec.c_str());
glob_t globbuf;
int r = glob(filespec.c_str (), 0, NULL, & globbuf);
if (r != GLOB_NOSPACE && r != GLOB_ABORTED && r != GLOB_NOMATCH)
{
if (globbuf.gl_pathc > 1)
clog << _("Incorrect number of modules in server response") << endl;
else
{
assert (globbuf.gl_pathc == 1);
string modname = globbuf.gl_pathv[0];
if (s.verbose >= 3)
clog << _(" found ") << modname << endl;
if (! s.save_module)
{
// Take the module name from the file name: the last path component
// with its final three characters (".ko") removed.
vector<string> components;
tokenize (modname, components, "/");
s.module_name = components.back ();
s.module_name.erase(s.module_name.size() - 3);
}
// The location of uprobes.ko depends on the server's protocol version.
string uprobes_ko;
if (server_version < "1.6")
uprobes_ko = s.tmpdir + "/server/uprobes.ko";
else
uprobes_ko = s.tmpdir + "/uprobes/uprobes.ko";
if (file_exists (uprobes_ko))
{
s.need_uprobes = true;
s.uprobes_path = uprobes_ko;
}
}
}
else if (s.have_script)
{
// A script was supplied and the server reported success, yet no
// module came back: treat that as a failure.
if (rc == 0)
{
clog << _("No module was returned by the server.") << endl;
rc = 1;
}
}
globfree (& globbuf);
}
// Copy the server's MOK public certificate, if present, to a file of the
// same name given as a relative path.
string server_MOK_public_cert = s.tmpdir + "/server/" MOK_PUBLIC_CERT_NAME;
if (file_exists (server_MOK_public_cert))
{
string dst = MOK_PUBLIC_CERT_NAME;
copy_file (server_MOK_public_cert, dst, (s.verbose >= 3));
}
// Replay the server's stderr on clog and its stdout on cout.
filename = server_tmpdir + "/stderr";
flush_to_stream (filename, clog);
filename = server_tmpdir + "/stdout";
flush_to_stream (filename, cout);
return rc;
}
int
compile_server_client::read_from_file (const string &fname, int &data)
{
errno = 0;
ifstream f (fname.c_str ());
if (! f.good ())
{
clog << _F("Unable to open file '%s' for reading: ", fname.c_str());
goto error;
}
errno = 0;
f >> data;
if (f.fail ())
{
clog << _F("Unable to read from file '%s': ", fname.c_str());
goto error;
}
return 0;
error:
if (errno)
clog << strerror (errno) << endl;
else
clog << _("unknown error") << endl;
return 1; }
// Write DATA, formatted with operator<<, to the file FNAME, replacing any
// previous contents.  Nothing is appended after DATA.
// Returns 0 on success.  On failure, reports the problem on clog (with the
// errno text when errno is set) and returns 1.
template <class T>
int
compile_server_client::write_to_file (const string &fname, const T &data)
{
  errno = 0;
  ofstream f (fname.c_str ());
  bool ok = f.good ();
  if (! ok)
    clog << _F("Unable to open file '%s' for writing: ", fname.c_str());
  else
    {
      // Clear errno BEFORE writing so that the cause of a failed write is
      // still available afterwards, and flush so that a failure to get the
      // data out of the stream buffer is detected here.
      errno = 0;
      f << data;
      f.flush ();
      ok = ! f.fail ();
      if (! ok)
        clog << _F("Unable to write to file '%s': ", fname.c_str());
    }

  if (ok)
    return 0;

  // Complete the message started above with the reason, if known.
  if (errno)
    clog << strerror (errno) << endl;
  else
    clog << _("unknown error") << endl;
  return 1;
}
int
compile_server_client::flush_to_stream (const string &fname, ostream &o)
{
errno = 0;
ifstream f (fname.c_str ());
if (! f.good ())
{
clog << _F("Unable to open file '%s' for reading: ", fname.c_str());
goto error;
}
while (1)
{
errno = 0;
int c = f.get();
if (f.eof ()) return 0; if (! f.good()) break;
o.put(c);
if (! o.good()) break;
}
error:
if (errno)
clog << strerror (errno) << endl;
else
clog << _("unknown error") << endl;
return 1; }
// For servers whose protocol version is below 1.6, report the version and
// note that such servers ignore the client's localization information.
void
compile_server_client::show_server_compatibility () const
{
  const bool old_server = server_version < "1.6";
  if (! old_server)
    return;

  clog << _F("Server protocol version is %s\n", server_version.v);
  clog << _("The server does not use localization information passed by the client\n");
}
static void
trust_already_in_place (
const compile_server_info &server,
const vector<compile_server_info> &server_list,
const string cert_db_path,
bool revoking
)
{
string purpose;
if (cert_db_path == signing_cert_db_path ())
purpose = _("as a module signer for all users");
else
{
purpose = _("as an SSL peer");
if (cert_db_path == global_ssl_cert_db_path ())
purpose += _(" for all users");
else
purpose += _(" for the current user");
}
unsigned limit = server_list.size ();
for (unsigned i = 0; i < limit; ++i)
{
if (server.certinfo != server_list[i].certinfo)
continue;
clog << server_list[i] << _(" is already ");
if (revoking)
clog << _("untrusted ") << purpose << endl;
else
clog << _("trusted ") << purpose << endl;
}
}
// Add trust in the servers of SERVER_LIST to the certificate database at
// CERT_DB_PATH by connecting to each server that is not already trusted
// (client_connect is asked for "permanent" trust). Afterwards the database
// files (*.db) in that directory are given mode 0644.
static void
add_server_trust (
  systemtap_session &s,
  const string &cert_db_path,
  vector<compile_server_info> &server_list
)
{
  // Collect the servers already found in this database. This call
  // initializes and shuts down NSS itself, so it is made before NSS is
  // initialized below.
  vector<compile_server_info> already_trusted;
  get_server_info_from_db (s, already_trusted, cert_db_path);

  // The database directory must exist.
  if (create_dir (cert_db_path.c_str (), 0755) != 0)
    {
      clog << _F("Unable to find or create the client certificate database directory %s: ", cert_db_path.c_str());
      perror ("");
      return;
    }

  // Declared here, ahead of any jump to cleanup: below.
  vector<string> processed_certs;

  s.NSPR_init ();

  // NOTE(review): the second argument (1) presumably requests read/write
  // access to the database -- confirm against nssInit.
  SECStatus secStatus = nssInit (cert_db_path.c_str (), 1);
  if (secStatus != SECSuccess)
    {
      goto cleanup;
    }

  // Allow every implemented cipher suite.
  do {
    const PRUint16 *cipher;
    for (cipher = SSL_ImplementedCiphers; *cipher != 0; ++cipher)
      SSL_CipherPolicySet(*cipher, SSL_ALLOWED);
  } while (0);
  SSL_ClearSessionCache ();

  for (vector<compile_server_info>::iterator server = server_list.begin();
       server != server_list.end ();
       ++server)
    {
      // Servers with known certificate info need to be handled only once per
      // certificate, and not at all if that certificate is already trusted.
      // Servers without certificate info cannot be compared and are always
      // contacted.
      if (! server->certinfo.empty ())
        {
          if (find (processed_certs.begin (), processed_certs.end (),
                    server->certinfo) != processed_certs.end ())
            continue;
          processed_certs.push_back (server->certinfo);

          if (find (already_trusted.begin (), already_trusted.end (), *server) !=
              already_trusted.end ())
            {
              if (s.verbose >= 2)
                trust_already_in_place (*server, server_list, cert_db_path, false);
              continue;
            }
        }

      // Both an address and a port are needed to contact the server.
      if (! server->hasAddress() || server->port == 0)
        continue;
      server->setAddressPort (server->port);

      int rc = client_connect (*server, NULL, NULL, "permanent");
      if (rc != SUCCESS)
        {
          clog << _F("Unable to connect to %s", lex_cast(*server).c_str()) << endl;
          nssError ();
          // Extra hint for link-local IPv6 addresses lacking a scope id.
          if (isIPv6LinkLocal (server->address) && server->address.ipv6.scope_id == 0)
            {
              clog << _(" The address is an IPv6 link-local address with no scope specifier.")
                   << endl;
            }
        }
    }

 cleanup:
  SSL_ClearSessionCache ();
  nssCleanup (cert_db_path.c_str ());

  // Make the database files readable (mode 0644).
  glob_t globbuf;
  string filespec = cert_db_path + "/*.db";
  if (s.verbose >= 3)
    clog << _F("Searching \"%s\"\n", filespec.c_str());
  int r = glob (filespec.c_str (), 0, NULL, & globbuf);
  if (r != GLOB_NOSPACE && r != GLOB_ABORTED && r != GLOB_NOMATCH)
    {
      for (unsigned i = 0; i < globbuf.gl_pathc; ++i)
        {
          string filename = globbuf.gl_pathv[i];
          if (s.verbose >= 3)
            clog << _(" found ") << filename << endl;

          if (chmod (filename.c_str (), 0644) != 0)
            {
              s.print_warning("Unable to change permissions on " + filename + ": ");
              perror ("");
            }
        }
      // Release the path list allocated by glob(); it was previously leaked.
      globfree (& globbuf);
    }
}
// Revoke trust in the servers of server_list by deleting their certificates
// from the certificate database at cert_db_path. A certificate is matched to
// a server by comparing its serial number with the server's certinfo.
// Servers whose certificate is not in the database are reported (at
// verbosity >= 2) as already untrusted.
static void
revoke_server_trust (
systemtap_session &s,
const string &cert_db_path,
const vector<compile_server_info> &server_list
)
{
// Without a database there is nothing to revoke. At high verbosity, say so
// for each server.
if (! file_exists (cert_db_path))
{
if (s.verbose >= 5)
{
clog << _F("Certificate database '%s' does not exist",
cert_db_path.c_str()) << endl;
for (vector<compile_server_info>::const_iterator server = server_list.begin();
server != server_list.end ();
++server)
trust_already_in_place (*server, server_list, cert_db_path, true);
}
return;
}
// These are declared ahead of the jumps to cleanup: below.
CERTCertDBHandle *handle;
PRArenaPool *tmpArena = NULL;
CERTCertList *certs = NULL;
CERTCertificate *db_cert;
vector<string> processed_certs;
const char *nickname;
s.NSPR_init ();
// NOTE(review): the second argument (1) presumably requests read/write
// access to the database -- confirm against nssInit.
SECStatus secStatus = nssInit (cert_db_path.c_str (), 1);
if (secStatus != SECSuccess)
{
goto cleanup;
}
handle = CERT_GetDefaultCertDB();
// NOTE(review): the arena is allocated here and released at cleanup:, but
// nothing else in this function uses it.
tmpArena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
if (! tmpArena)
{
clog << _("Out of memory:");
nssError ();
goto cleanup;
}
nickname = server_cert_nickname ();
for (vector<compile_server_info>::const_iterator server = server_list.begin();
server != server_list.end ();
++server)
{
// A server without certificate info cannot be matched against the database.
if (server->certinfo.empty ())
continue;
// Each distinct certificate needs to be processed only once.
if (find (processed_certs.begin (), processed_certs.end (),
server->certinfo) != processed_certs.end ())
continue;
processed_certs.push_back (server->certinfo);
// Look for any certificate stored under the server nickname.
db_cert = PK11_FindCertFromNickname (nickname, NULL);
if (! db_cert)
{
if (s.verbose >= 2)
trust_already_in_place (*server, server_list, cert_db_path, true);
continue;
}
// Use that certificate's subject to list all certificates with the same
// subject; the certificate itself is released once the list is requested.
certs = CERT_CreateSubjectCertList (NULL, handle, & db_cert->derSubject,
PR_Now (), PR_FALSE);
CERT_DestroyCertificate (db_cert);
if (! certs)
{
clog << _F("Unable to query certificate database %s: ",
cert_db_path.c_str()) << endl;
PORT_SetError (SEC_ERROR_LIBRARY_FAILURE);
nssError ();
goto cleanup;
}
// Search the list for the certificate whose serial number matches this
// server. On a match the loop is left with db_cert referring to it.
CERTCertListNode *node;
for (node = CERT_LIST_HEAD (certs);
! CERT_LIST_END (node, certs);
node = CERT_LIST_NEXT (node))
{
db_cert = node->cert;
string serialNumber = get_cert_serial_number (db_cert);
if (serialNumber != server->certinfo)
continue;
break;
}
// Reaching the end of the list means that no certificate matched.
if (CERT_LIST_END (node, certs))
{
if (s.verbose >= 2)
trust_already_in_place (*server, server_list, cert_db_path, true);
}
else
{
secStatus = SEC_DeletePermCertificate (db_cert);
if (secStatus != SECSuccess)
{
clog << _F("Unable to remove certificate from %s: ",
cert_db_path.c_str()) << endl;
nssError ();
}
}
// Reset certs so that the assertion at cleanup: holds.
CERT_DestroyCertList (certs);
certs = NULL;
}
cleanup:
// Every path reaching this point has either never created a list or has
// already destroyed it.
assert(!certs);
if (tmpArena)
PORT_FreeArena (tmpArena, PR_FALSE);
nssCleanup (cert_db_path.c_str ());
}
// Add to servers an entry for each server certificate found in the
// certificate database at cert_db_path. Each entry carries the host name
// taken from the certificate's subject alternative name extension and the
// certificate serial number as certinfo. Online servers with the same
// certificate info are added as well.
static void
get_server_info_from_db (
systemtap_session &s,
vector<compile_server_info> &servers,
const string &cert_db_path
)
{
// Nothing to report if the database does not exist.
if (! file_exists (cert_db_path))
{
if (s.verbose >= 5)
clog << _F("Certificate database '%s' does not exist.",
cert_db_path.c_str()) << endl;
return;
}
s.NSPR_init ();
// NOTE(review): nssInit is called without the extra argument used by the
// trust-modifying functions in this file; presumably this opens the
// database read-only -- confirm against nssInit.
SECStatus secStatus = nssInit (cert_db_path.c_str ());
if (secStatus != SECSuccess)
{
return;
}
// These are declared ahead of the jumps to cleanup: below.
PRArenaPool *tmpArena = NULL;
CERTCertList *certs = get_cert_list_from_db (server_cert_nickname ());
if (! certs)
{
if (s.verbose >= 5)
clog << _F("No certificate found in database %s", cert_db_path.c_str ()) << endl;
goto cleanup;
}
// Arena used for decoding the alt name extensions; released at cleanup:.
tmpArena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
if (! tmpArena)
{
clog << _("Out of memory:");
nssError ();
goto cleanup;
}
for (CERTCertListNode *node = CERT_LIST_HEAD (certs);
! CERT_LIST_END (node, certs);
node = CERT_LIST_NEXT (node))
{
compile_server_info server_info;
CERTCertificate *db_cert = node->cert;
// Fetch the subject alternative name extension. Certificates without one
// are reported and skipped.
SECItem subAltName;
subAltName.data = NULL;
secStatus = CERT_FindCertExtension (db_cert,
SEC_OID_X509_SUBJECT_ALT_NAME,
& subAltName);
if (secStatus != SECSuccess || ! subAltName.data)
{
clog << _("Unable to find alt name extension on server certificate: ") << endl;
nssError ();
continue;
}
// Decode the extension into the arena; the raw item is freed right away.
CERTGeneralName *nameList = CERT_DecodeAltNameExtension (tmpArena, & subAltName);
SECITEM_FreeItem(& subAltName, PR_FALSE);
if (! nameList)
{
clog << _("Unable to decode alt name extension on server certificate: ") << endl;
nssError ();
continue;
}
// The first decoded name is required to be a DNS name; it becomes the
// host name of the entry.
assert (nameList->type == certDNSName);
server_info.host_name = string ((const char *)nameList->name.other.data,
nameList->name.other.len);
server_info.certinfo = get_cert_serial_number (db_cert);
add_server_info (server_info, servers);
// Also add the online servers matching this certificate info (and port,
// where known).
vector<compile_server_info> online_servers;
get_or_keep_online_server_info (s, online_servers, false);
keep_server_info_with_cert_and_port (s, server_info, online_servers);
add_server_info (online_servers, servers);
}
cleanup:
if (certs)
CERT_DestroyCertList (certs);
if (tmpArena)
PORT_FreeArena (tmpArena, PR_FALSE);
nssCleanup (cert_db_path.c_str ());
}
// Print the fields of a compile_server_info in the form
//   host=... address=... port=... sysinfo="..." version=... certinfo="..."
// followed by mok_fingerprints="..." when any are present. Unknown fields
// are printed as "unknown" (the address as "offline"). Nothing is printed
// for an empty item.
ostream &operator<< (ostream &s, const compile_server_info &i)
{
  // Empty items produce no output.
  if (i.empty ())
    return s;

  s << " host=";
  if (! i.host_name.empty ())
    s << i.host_name;
  else
    s << "unknown";

  s << " address=";
  if (i.hasAddress())
    {
      PRStatus prStatus;
      switch (i.address.raw.family)
        {
        case PR_AF_INET:
        case PR_AF_INET6:
          {
            // The macro must be alone on its line: anything following it on
            // the same line becomes part of its replacement text, which
            // would swallow the declaration of buf.
#define MAX_NETADDR_SIZE 46
            char buf[MAX_NETADDR_SIZE];
            prStatus = PR_NetAddrToString(& i.address, buf, sizeof (buf));
            if (prStatus == PR_SUCCESS) {
              s << buf;
              break;
            }
          }
          // Fall through: an address which cannot be formatted is shown as
          // offline.
        default:
          s << "offline";
          break;
        }
    }
  else
    s << "offline";

  s << " port=";
  if (i.port != 0)
    s << i.port;
  else
    s << "unknown";

  s << " sysinfo=\"";
  if (! i.sysinfo.empty ())
    s << i.sysinfo << '"';
  else
    s << "unknown\"";

  s << " version=";
  if (! i.version.empty ())
    s << i.version;
  else
    s << "unknown";

  s << " certinfo=\"";
  if (! i.certinfo.empty ())
    s << i.certinfo << '"';
  else
    s << "unknown\"";

  if (! i.mok_fingerprints.empty ())
    {
      // FIXME: this part of the output is a plain comma-separated list.
      s << " mok_fingerprints=\"";
      vector<string>::const_iterator it;
      for (it = i.mok_fingerprints.begin (); it != i.mok_fingerprints.end ();
           it++)
        {
          if (it != i.mok_fingerprints.begin ())
            s << ", ";
          s << *it;
        }
      s << "\"";
    }

  return s;
}
// Print every non-empty entry of a server list, one per line. A list which
// has no entries, or only a single empty one, is printed as "No Servers".
ostream &operator<< (ostream &s, const vector<compile_server_info> &v)
{
  const bool nothing_to_show =
    v.size () == 0 || (v.size () == 1 && v[0].empty());
  if (nothing_to_show)
    {
      s << "No Servers" << endl;
      return s;
    }

  for (vector<compile_server_info>::const_iterator entry = v.begin ();
       entry != v.end ();
       ++entry)
    {
      if (entry->empty())
        continue;
      s << *entry << endl;
    }
  return s;
}
// Assign y to x. If x held an IPv6 address with a non-zero scope id
// beforehand, that scope id is kept rather than overwritten by y's.
PRNetAddr &
copyNetAddr (PRNetAddr &x, const PRNetAddr &y)
{
  // Remember the scope id x carries before it gets overwritten.
  const bool had_ipv6 = (x.raw.family == PR_AF_INET6);
  const PRUint32 previous_scope = had_ipv6 ? x.ipv6.scope_id : 0;

  x = y;

  if (previous_scope != 0)
    x.ipv6.scope_id = previous_scope;
  return x;
}
// Compare two network addresses. Addresses of different families, or of a
// family other than IPv4/IPv6, are never equal. IPv6 scope ids take part in
// the comparison only when both are non-zero. Ports are not compared.
bool
operator== (const PRNetAddr &x, const PRNetAddr &y)
{
  if (x.raw.family != y.raw.family)
    return false;

  if (x.raw.family == PR_AF_INET)
    return x.inet.ip == y.inet.ip;

  if (x.raw.family == PR_AF_INET6)
    {
      const bool both_scoped = x.ipv6.scope_id != 0 && y.ipv6.scope_id != 0;
      if (both_scoped && x.ipv6.scope_id != y.ipv6.scope_id)
        return false;
      return memcmp (& x.ipv6.ip, & y.ipv6.ip, sizeof(x.ipv6.ip)) == 0;
    }

  return false;
}
// Inequality of network addresses: the negation of operator==.
bool
operator!= (const PRNetAddr &x, const PRNetAddr &y)
{
  const bool same = (x == y);
  return ! same;
}
// Copy the bytes of a system IPv6 address into an NSPR IPv6 address and
// return the destination. Both types must have the same size.
static PRIPv6Addr &
copyAddress (PRIPv6Addr &PRin6, const in6_addr &in6)
{
  const size_t nbytes = sizeof (PRin6);
  assert (nbytes == sizeof (in6));

  void *destination = & PRin6;
  const void *source = & in6;
  memcpy (destination, source, nbytes);
  return PRin6;
}
// Build the default server property list: "online,trusted,compatible", with
// ",signer" appended unless the session's privilege includes pr_stapdev.
static string
default_server_spec (const systemtap_session &s)
{
  const bool needs_signer = ! pr_contains (s.privilege, pr_stapdev);

  string spec ("online,trusted,compatible");
  if (needs_signer)
    spec.append (",signer");
  return spec;
}
static int
server_spec_to_pmask (const string &server_spec)
{
XXX string working_spec = server_spec;
vector<string> properties;
tokenize (working_spec, properties, ",");
int pmask = 0;
unsigned limit = properties.size ();
for (unsigned i = 0; i < limit; ++i)
{
const string &property = properties[i];
if (property.empty ())
continue;
if (property == "all")
{
pmask |= compile_server_all;
}
else if (property == "specified")
{
pmask |= compile_server_specified;
}
else if (property == "trusted")
{
pmask |= compile_server_trusted;
}
else if (property == "online")
{
pmask |= compile_server_online;
}
else if (property == "compatible")
{
pmask |= compile_server_compatible;
}
else if (property == "signer")
{
pmask |= compile_server_signer;
}
else
{
XXX clog << _F("WARNING: unsupported compile server property: %s", property.c_str())
<< endl;
}
}
return pmask;
}
// Run a server status query for each status string recorded in the session.
void
query_server_status (systemtap_session &s)
{
  const unsigned query_count = s.server_status_strings.size ();
  unsigned n = 0;
  while (n < query_count)
    {
      query_server_status (s, s.server_status_strings[n]);
      ++n;
    }
}
// Print, on clog, the servers matching the property list STATUS_STRING.
// An empty string means "specified"; if that is the query and no server has
// been specified, the default property list is used instead. Servers without
// certificate info are not listed.
static void
query_server_status (systemtap_session &s, const string &status_string)
{
  // An empty query defaults to "specified".
  string working_string = status_string;
  if (working_string.empty ())
    working_string = "specified";

  // "specified" with no (or only an empty) server specification falls back
  // to the default query.
  // TODO: this may be unnecessary if the underlying queries handle
  // "specified" themselves.
  if (working_string == "specified" &&
      (s.specified_servers.empty () ||
       (s.specified_servers.size () == 1 && s.specified_servers[0].empty ())))
    working_string = default_server_spec (s);

  int pmask = server_spec_to_pmask (working_string);

  // The servers matching the query.
  vector<compile_server_info> raw_servers;
  get_server_info (s, pmask, raw_servers);

  // Start from everything known about all servers and keep only what is
  // common with the query result, so that the listing is as complete as
  // possible.
  vector<compile_server_info> servers;
  get_all_server_info (s, servers);
  keep_common_server_info (raw_servers, servers);

  preferred_order (servers);

  clog << _F("Systemtap Compile Server Status for '%s'", working_string.c_str()) << endl;
  bool found = false;
  unsigned limit = servers.size ();
  for (unsigned i = 0; i < limit; ++i)
    {
      assert (! servers[i].empty ());
      // Servers without certificate info are skipped.
      // TODO: such a server could be contacted to obtain its certificate.
      if (servers[i].certinfo.empty ())
        continue;
      clog << servers[i] << endl;
      found = true;
    }
  if (! found)
    clog << _("No servers found") << endl;
}
// Add or revoke trust in the specified servers as requested by the session's
// server trust specification, a comma-separated list of:
//   ssl       - trust as an SSL peer (the default if neither ssl nor signer
//               is given)
//   signer    - trust as a module signer (root only)
//   revoke    - revoke the trust instead of adding it
//   all-users - use the global SSL database instead of the private one
//               (root only)
//   no-prompt - do not ask for confirmation
// Unrecognized components produce a warning. Nothing is done when the
// specification is empty.
void
manage_server_trust (systemtap_session &s)
{
  if (s.server_trust_spec.empty ())
    return;

  // Analyze the trust specification.
  vector<string> components;
  tokenize (s.server_trust_spec, components, ",");
  bool ssl = false;
  bool signer = false;
  bool revoke = false;
  bool all_users = false;
  bool no_prompt = false;
  bool error = false;
  for (vector<string>::const_iterator i = components.begin ();
       i != components.end ();
       ++i)
    {
      if (*i == "ssl")
        ssl = true;
      else if (*i == "signer")
        {
          if (geteuid () != 0)
            {
              clog << _("Only root can specify 'signer' on --trust-servers") << endl;
              error = true;
            }
          else
            signer = true;
        }
      else if (*i == "revoke")
        revoke = true;
      else if (*i == "all-users")
        {
          if (geteuid () != 0)
            {
              clog << _("Only root can specify 'all-users' on --trust-servers") << endl;
              error = true;
            }
          else
            all_users = true;
        }
      else if (*i == "no-prompt")
        no_prompt = true;
      else
        s.print_warning("Unrecognized server trust specification: " + *i);
    }
  if (error)
    return;

  s.NSPR_init ();

  // The servers to operate on: those specified, without falling back to the
  // default servers.
  vector<compile_server_info> server_list;
  get_specified_server_info (s, server_list, true);

  unsigned limit = server_list.size ();
  if (limit == 0)
    {
      clog << _("No servers identified for trust") << endl;
      return;
    }

  // Describe the request. 'ssl' is the default kind of trust.
  if (! ssl && ! signer)
    ssl = true;
  ostringstream trustString;
  if (ssl)
    {
      trustString << _("as an SSL peer");
      if (all_users)
        trustString << _(" for all users");
      else
        trustString << _(" for the current user");
    }
  if (signer)
    {
      if (ssl)
        trustString << _(" and ");
      trustString << _("as a module signer for all users");
    }

  // Announce the operation, as a question when confirmation is required.
  if (no_prompt)
    {
      if (revoke)
        clog << _("Revoking trust ");
      else
        clog << _("Adding trust ");
    }
  else
    {
      if (revoke)
        clog << _("Revoke trust ");
      else
        clog << _("Add trust ");
    }
  clog << _F("in the following servers %s", trustString.str().c_str());
  if (! no_prompt)
    clog << '?';
  clog << endl;
  for (unsigned i = 0; i < limit; ++i)
    clog << " " << server_list[i] << endl;
  if (! no_prompt)
    {
      clog << "[y/N] " << flush;

      // Proceed only on a response starting with 'y' or 'Y'. A failed read
      // (for example end of input) leaves the response empty and counts as
      // "no"; check for that explicitly rather than indexing an empty string.
      string response;
      cin >> response;
      if (response.empty () || (response[0] != 'y' && response[0] != 'Y'))
        {
          clog << _("Server trust unchanged") << endl;
          return;
        }
    }

  // Carry out the request on each selected database.
  string cert_db_path;
  if (ssl)
    {
      if (all_users)
        cert_db_path = global_ssl_cert_db_path ();
      else
        cert_db_path = private_ssl_cert_db_path ();
      if (revoke)
        revoke_server_trust (s, cert_db_path, server_list);
      else
        add_server_trust (s, cert_db_path, server_list);
    }
  if (signer)
    {
      cert_db_path = signing_cert_db_path ();
      if (revoke)
        revoke_server_trust (s, cert_db_path, server_list);
      else
        add_server_trust (s, cert_db_path, server_list);
    }
}
// Return the session's compile server cache, creating it on first use.
static compile_server_cache*
cscache(systemtap_session& s)
{
  if (s.server_cache)
    return s.server_cache;

  s.server_cache = new compile_server_cache();
  return s.server_cache;
}
// Fill SERVERS according to the property mask PMASK. The queries run in a
// fixed order: 'all' and 'specified' add to the list; each of 'online',
// 'trusted', 'signer' and 'compatible' adds to the list if no earlier query
// ran, and filters it otherwise.
static void
get_server_info (
  systemtap_session &s,
  int pmask,
  vector<compile_server_info> &servers
)
{
  // Becomes true once any query has run; from then on the remaining queries
  // filter rather than accumulate.
  bool filtering = false;

  if ((pmask & compile_server_all) != 0)
    {
      get_all_server_info (s, servers);
      filtering = true;
    }
  if ((pmask & compile_server_specified) != 0)
    {
      get_specified_server_info (s, servers);
      filtering = true;
    }
  if ((pmask & compile_server_online) != 0)
    {
      get_or_keep_online_server_info (s, servers, filtering);
      filtering = true;
    }
  if ((pmask & compile_server_trusted) != 0)
    {
      get_or_keep_trusted_server_info (s, servers, filtering);
      filtering = true;
    }
  if ((pmask & compile_server_signer) != 0)
    {
      get_or_keep_signing_server_info (s, servers, filtering);
      filtering = true;
    }
  if ((pmask & compile_server_compatible) != 0)
    get_or_keep_compatible_server_info (s, servers, filtering);
}
// Add to SERVERS everything known about online servers, trusted SSL peers
// and trusted module signers. The combined list is built once and kept in
// the session's server cache.
static void
get_all_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers
)
{
  vector<compile_server_info>& known = cscache(s)->all_servers;
  const bool need_lookup = known.empty ();
  if (need_lookup)
    {
      get_or_keep_online_server_info (s, known, false);
      get_or_keep_trusted_server_info (s, known, false);
      get_or_keep_signing_server_info (s, known, false);

      if (s.verbose >= 4)
        {
          clog << _("All known servers:") << endl;
          clog << known;
        }
    }

  add_server_info (known, servers);
}
// Add to SERVERS the servers selected by the default property list. The
// list is computed once and kept in the session's server cache.
static void
get_default_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers
)
{
  if (s.verbose >= 3)
    clog << _("Using the default servers") << endl;

  vector<compile_server_info>& defaults = cscache(s)->default_servers;
  if (defaults.empty ())
    {
      const string spec = default_server_spec (s);
      get_server_info (s, server_spec_to_pmask (spec), defaults);

      if (s.verbose >= 3)
        {
          clog << _("Default servers are:") << endl;
          clog << defaults;
        }
    }

  add_server_info (defaults, servers);
}
// Parse PSTR as a decimal port number. On success the port is stored in
// SERVER_INFO, which is also marked as fully specified, and true is
// returned. Otherwise a message is written to clog and false is returned.
static bool
isPort (const char *pstr, compile_server_info &server_info)
{
  errno = 0;
  char *estr;
  unsigned long p = strtoul (pstr, & estr, 10);
  // Reject conversion errors, input containing no digits at all
  // (estr == pstr; without this check such input is taken to be port 0),
  // trailing characters and values which do not fit a port number.
  if (errno != 0 || estr == pstr || *estr != '\0' || p > USHRT_MAX)
    {
      clog << _F("Invalid port number specified: %s", pstr) << endl;
      return false;
    }
  server_info.port = p;
  server_info.fully_specified = true;
  return true;
}
// Decide whether SERVER is an IPv6 address, optionally enclosed in [] and
// then optionally followed by :<port>. The last component of the address may
// carry a %<interface> suffix. On success the address (and port, if given)
// is stored in SERVER_INFO and true is returned.
static bool
isIPv6 (const string &server, compile_server_info &server_info)
{
  assert (! server.empty());

  // Separate a bracketed address from whatever follows the closing bracket.
  string ip;
  string::size_type portIx;
  if (server[0] == '[')
    {
      string::size_type endBracket = server.find (']');
      if (endBracket == string::npos)
        return false;
      ip = server.substr (1, endBracket - 1);
      portIx = endBracket + 1;
    }
  else
    {
      ip = server;
      portIx = string::npos;
    }

  // There can be at most 8 colon-separated components.
  unsigned empty = 0;
  vector<string> components;
  tokenize_full (ip, components, ":");
  if (components.size() > 8)
    return false;

  // Each component is either empty (at most one may be) or consists of hex
  // digits, at most 4 of them not counting leading zeroes.
  string interface;
  for (unsigned i = 0; i < components.size(); ++i)
    {
      if (components[i].empty())
        {
          if (++empty > 1)
            return false;
        }
      // Strip an interface suffix from the final component before checking
      // its digits.
      if (i == components.size() - 1)
        {
          size_t ix = components[i].find ('%');
          if (ix != string::npos)
            {
              interface = components[i].substr(ix);
              components[i] = components[i].substr(0, ix);
            }
        }
      // Skip leading zeroes.
      unsigned j;
      for (j = 0; j < components[i].size(); ++j)
        {
          if (components[i][j] != '0')
            break;
        }
      if (components[i].size() - j > 4)
        return false;
      for (; j < components[i].size(); ++j)
        {
          // The <ctype.h> functions require a value representable as
          // unsigned char; a plain char may be negative.
          if (! isxdigit ((unsigned char) components[i][j]))
            return false;
        }
    }
  // Without an empty component there must be exactly 8 components.
  if (! empty && components.size() != 8)
    return false;

  // The final word on validity belongs to NSPR.
  PRStatus prStatus = PR_StringToNetAddr (ip.c_str(), & server_info.address);
  if (prStatus != PR_SUCCESS)
    return false;

  // Anything after the closing bracket must be ':' followed by a port.
  if (portIx != string::npos)
    {
      string port = server.substr (portIx);
      if (port.size() != 0)
        {
          if (port.size() < 2 || port[0] != ':')
            return false;
          port = port.substr (1);
          if (! isPort (port.c_str(), server_info))
            return false;
        }
    }
  else
    server_info.port = 0;

  return true;
}
static bool
isIPv4 (const string &server, compile_server_info &server_info)
{
assert (! server.empty());
vector<string> components;
tokenize (server, components, ":");
if (components.size() > 2)
return false;
string addr;
string port;
if (components.size() <= 1)
addr = server;
else {
addr = components[0];
port = components[1];
}
components.clear ();
tokenize (addr, components, ".");
if (components.size() != 4)
return false;
for (unsigned i = 0; i < components.size(); ++i)
{
if (components[i].empty())
return false; errno = 0;
char *estr;
long p = strtol (components[i].c_str(), & estr, 10);
if (errno != 0 || *estr != '\0' || p < 0 || p > 255)
return false; }
PRStatus prStatus = PR_StringToNetAddr (addr.c_str(), & server_info.address);
if (prStatus != PR_SUCCESS)
return false;
if (! port.empty ()) {
if (! isPort (port.c_str(), server_info))
return false; }
else
server_info.port = 0;
return true; }
static bool
isCertSerialNumber (const string &server, compile_server_info &server_info)
{
assert (! server.empty());
string host = server;
vector<string> components;
tokenize (host, components, ":");
switch (components.size ())
{
case 6:
if (! isPort (components.back().c_str(), server_info))
return false; host = host.substr (0, host.find_last_of (':'));
case 5:
server_info.certinfo = host;
break;
default:
return false; }
return true; }
static bool
isDomain (const string &server, compile_server_info &server_info)
{
assert (! server.empty());
string host = server;
vector<string> components;
tokenize (host, components, ":");
switch (components.size ())
{
case 2:
if (! isPort (components.back().c_str(), server_info))
return false; host = host.substr (0, host.find_last_of (':'));
case 1:
server_info.host_name = host;
break;
default:
return false; }
return true;
}
// Add to SERVERS the servers named by the session's server specifications.
// Each specification may be an IPv6 address, an IPv4 address, a certificate
// serial number or a domain name. Where nothing (or an empty item) was
// specified, the default servers are used unless NO_DEFAULT is set. The
// result is computed once and kept in the session's server cache.
static void
get_specified_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers,
  bool no_default
)
{
  vector<compile_server_info>& cache = cscache(s)->specified_servers;
  if (cache.empty ())
    {
      // An empty entry records that the search has been performed, even if
      // it yields nothing.
      cache.push_back (compile_server_info ());

      const unsigned spec_count = s.specified_servers.size ();
      if (spec_count == 0)
        {
          if (s.verbose >= 3)
            clog << _("No servers specified") << endl;
          if (! no_default)
            get_default_server_info (s, cache);
        }

      for (unsigned n = 0; n < spec_count; ++n)
        {
          string &spec = s.specified_servers[n];
          if (spec.empty ())
            {
              if (s.verbose >= 3)
                clog << _("No servers specified") << endl;
              if (! no_default)
                get_default_server_info (s, cache);
              continue;
            }

          // Work out what kind of specification this is. Addresses and
          // serial numbers are taken as they are; domain names get resolved.
          compile_server_info parsed;
          vector<compile_server_info> candidates;
          const bool taken_literally =
            isIPv6 (spec, parsed) || isIPv4 (spec, parsed) ||
            isCertSerialNumber (spec, parsed);
          if (taken_literally)
            candidates.push_back (parsed);
          else if (isDomain (spec, parsed))
            resolve_host (s, parsed, candidates);
          else
            {
              clog << _F("Invalid server specification for --use-server: %s", spec.c_str())
                   << endl;
              continue;
            }

          // Complete each candidate, where necessary, from the known servers.
          vector<compile_server_info> known_servers;
          vector<compile_server_info> new_servers;
          for (size_t c = 0; c < candidates.size (); ++c)
            {
              compile_server_info &candidate = candidates[c];
              if (candidate.fully_specified)
                {
                  add_server_info (candidate, new_servers);
                  continue;
                }

              if (known_servers.empty ())
                get_all_server_info (s, known_servers);
              vector<compile_server_info> matched_servers = known_servers;
              keep_common_server_info (candidate, matched_servers);

              // Prefer matching known servers; otherwise accept the candidate
              // only if it is complete on its own.
              if (! matched_servers.empty())
                add_server_info (matched_servers, new_servers);
              else if (candidate.isComplete ())
                add_server_info (candidate, new_servers);
              else if (s.verbose >= 3)
                clog << _("Incomplete server spec: ") << candidate << endl;
            }

          if (s.verbose >= 3)
            {
              clog << _F("Servers matching %s: ", spec.c_str()) << endl;
              clog << new_servers;
            }

          if (! new_servers.empty())
            add_server_info (new_servers, cache);
        }

      if (s.verbose >= 2)
        {
          clog << _("All specified servers:") << endl;
          clog << cache;
        }
    }

  add_server_info (cache, servers);
}
// With KEEP set, reduce SERVERS to what it has in common with the servers
// trusted as SSL peers; otherwise add those servers to SERVERS. The trusted
// servers are read once from the private and the global SSL certificate
// databases and kept in the session's server cache.
static void
get_or_keep_trusted_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers,
  bool keep
)
{
  // Filtering an empty list cannot produce anything.
  if (keep && servers.empty ())
    return;

  vector<compile_server_info>& trusted = cscache(s)->trusted_servers;
  if (trusted.empty ())
    {
      // An empty entry records that the search has been performed.
      trusted.push_back (compile_server_info ());

      get_server_info_from_db (s, trusted, private_ssl_cert_db_path ());
      get_server_info_from_db (s, trusted, global_ssl_cert_db_path ());

      if (s.verbose >= 5)
        {
          clog << _("All servers trusted as ssl peers:") << endl;
          clog << trusted;
        }
    }

  if (! keep)
    add_server_info (trusted, servers);
  else
    keep_common_server_info (trusted, servers);
}
// With KEEP set, reduce SERVERS to what it has in common with the servers
// trusted as module signers; otherwise add those servers to SERVERS. The
// signers are read once from the signing certificate database and kept in
// the session's server cache.
static void
get_or_keep_signing_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers,
  bool keep
)
{
  // Filtering an empty list cannot produce anything.
  if (keep && servers.empty ())
    return;

  vector<compile_server_info>& signers = cscache(s)->signing_servers;
  if (signers.empty ())
    {
      // An empty entry records that the search has been performed.
      signers.push_back (compile_server_info ());

      get_server_info_from_db (s, signers, signing_cert_db_path ());

      if (s.verbose >= 5)
        {
          clog << _("All servers trusted as module signers:") << endl;
          clog << signers;
        }
    }

  if (! keep)
    add_server_info (signers, servers);
  else
    keep_common_server_info (signers, servers);
}
// Reduce SERVERS to the servers compatible with this session. With KEEP set
// the list is first reduced to what it has in common with the online
// servers; otherwise the online servers are added to it. A server is
// compatible if its sysinfo equals "<kernel_release> <architecture>" of the
// session and, when the session has MOK fingerprints, the server shares at
// least one of them. Without Avahi compatibility cannot be determined: a
// message is issued at verbosity >= 2 and, with KEEP set, the list is
// emptied.
static void
get_or_keep_compatible_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers,
  bool keep
)
{
#if HAVE_AVAHI
  // Filtering an empty list cannot produce anything.
  if (keep && servers.empty ())
    return;

  // Combine the list with the information about online servers.
  vector<compile_server_info> online_servers;
  get_or_keep_online_server_info (s, online_servers, false);
  if (keep)
    keep_common_server_info (online_servers, servers);
  else
    add_server_info (online_servers, servers);

  // Erase the incompatible servers. The vector shrinks while it is being
  // traversed, so the index advances only past entries which are kept.
  for (unsigned i = 0; i < servers.size (); )
    {
      assert (! servers[i].empty ());

      // The server must compile for this kernel release and architecture.
      if (servers[i].sysinfo != s.kernel_release + " " + s.architecture)
        {
          servers.erase (servers.begin () + i);
          continue;
        }

      // If the session has MOK fingerprints, the server must share one.
      if (! s.mok_fingerprints.empty ())
        {
          if (servers[i].mok_fingerprints.empty ())
            {
              servers.erase (servers.begin () + i);
              continue;
            }

          vector<string>::const_iterator it;
          bool mok_found = false;
          for (it = s.mok_fingerprints.begin(); it != s.mok_fingerprints.end(); it++)
            {
              if (find(servers[i].mok_fingerprints.begin(),
                       servers[i].mok_fingerprints.end(), *it)
                  != servers[i].mok_fingerprints.end ())
                {
                  mok_found = true;
                  break;
                }
            }

          if (! mok_found)
            {
              servers.erase (servers.begin () + i);
              continue;
            }
        }

      // Compatible: keep it and move on.
      ++i;
    }
#else
  // The condition must be on its own line: text following #else on the same
  // line is not compiled, which would make the message unconditional.
  if (s.verbose >= 2)
    clog << _("Unable to detect server compatibility without avahi") << endl;
  if (keep)
    servers.clear ();
#endif
}
// Erase from SERVERS every non-empty entry which does not match SERVER. An
// entry matches if its certificate info equals SERVER's and the ports do not
// conflict (they are equal, or at least one of them is 0). A matching entry
// without a port takes over SERVER's port and fully_specified flag.
static void
keep_server_info_with_cert_and_port (
  systemtap_session &,
  const compile_server_info &server,
  vector<compile_server_info> &servers
)
{
  assert (! server.certinfo.empty ());

  // The vector shrinks while it is being traversed, so the index advances
  // only past entries which are kept.
  unsigned ix = 0;
  while (ix < servers.size ())
    {
      compile_server_info &candidate = servers[ix];
      if (! candidate.empty ())
        {
          const bool same_cert = candidate.certinfo == server.certinfo;
          const bool ports_agree = candidate.port == 0 || server.port == 0 ||
                                   candidate.port == server.port;
          if (! (same_cert && ports_agree))
            {
              servers.erase (servers.begin () + ix);
              continue;
            }

          if (candidate.port == 0)
            {
              candidate.port = server.port;
              candidate.fully_specified = server.fully_specified;
            }
        }
      ++ix;
    }
}
// Resolve the host name of SERVER and add to RESOLVED_SERVERS one copy of
// SERVER for each IPv4/IPv6 address found, with the address (and the host
// name obtained for it, if any) filled in. If nothing could be resolved,
// SERVER itself is added. Results are kept per host name in the session's
// server cache.
static void
resolve_host (
  systemtap_session& s,
  compile_server_info &server,
  vector<compile_server_info> &resolved_servers
)
{
  vector<resolved_host>& cache_entry = cscache(s)->resolved_hosts[server.host_name];
  if (cache_entry.empty ())
    {
      const char *name = server.host_name.c_str();
      if (s.verbose >= 6)
        clog << _F("Looking up %s", name) << endl;

      // Ask for addresses of any family.
      struct addrinfo hints;
      memset(& hints, 0, sizeof (hints));
      hints.ai_family = AF_UNSPEC;
      struct addrinfo *results = 0;
      int rc = getaddrinfo (name, NULL, & hints, & results);
      if (rc == 0)
        {
          assert (results);
          for (const struct addrinfo *ai = results; ai != NULL; ai = ai->ai_next)
            {
              // Only IPv4 and IPv6 results are used.
              PRNetAddr netaddr;
              if (ai->ai_family == AF_INET6)
                {
                  struct sockaddr_in6 *sa6 = (struct sockaddr_in6 *)ai->ai_addr;
                  netaddr.ipv6.family = PR_AF_INET6;
                  netaddr.ipv6.scope_id = sa6->sin6_scope_id;
                  copyAddress (netaddr.ipv6.ip, sa6->sin6_addr);
                }
              else if (ai->ai_family == AF_INET)
                {
                  struct sockaddr_in *sa4 = (struct sockaddr_in *)ai->ai_addr;
                  netaddr.inet.family = PR_AF_INET;
                  netaddr.inet.ip = sa4->sin_addr.s_addr;
                }
              else
                continue;

              // Look up a host name for the address; leave it empty on failure.
              char host[NI_MAXHOST];
              int status = getnameinfo (ai->ai_addr, ai->ai_addrlen, host, sizeof (host), NULL, 0,
                                        NI_NAMEREQD | NI_IDN);
              if (status != 0)
                host[0] = '\0';

              cache_entry.push_back(resolved_host(host, netaddr));
            }
        }
      else if (s.verbose >= 6)
        clog << _F("%s not found: %s", name, gai_strerror (rc)) << endl;

      if (results)
        freeaddrinfo (results);
    }

  // Nothing resolved: pass on what was given.
  if (cache_entry.empty())
    {
      add_server_info (server, resolved_servers);
      return;
    }

  // One new server per resolved address, each starting out as a copy of the
  // given server.
  vector<compile_server_info> new_servers;
  for (size_t n = 0; n < cache_entry.size (); ++n)
    {
      const resolved_host &host = cache_entry[n];
      compile_server_info new_server = server;
      if (host.address.raw.family == AF_INET)
        {
          new_server.address.inet.family = PR_AF_INET;
          new_server.address.inet.ip = host.address.inet.ip;
        }
      else
        {
          new_server.address.ipv6.family = PR_AF_INET6;
          new_server.address.ipv6.scope_id = host.address.ipv6.scope_id;
          new_server.address.ipv6.ip = host.address.ipv6.ip;
        }
      if (!host.host_name.empty())
        new_server.host_name = host.host_name;
      add_server_info (new_server, new_servers);
    }

  if (s.verbose >= 6)
    {
      clog << _F("%s resolves to:", server.host_name.c_str()) << endl;
      clog << new_servers;
    }

  add_server_info (new_servers, resolved_servers);
}
#if HAVE_AVAHI
// State shared with the Avahi callbacks, which receive a pointer to it as
// their userdata argument.
struct browsing_context {
// The poll loop; the callbacks stop it with avahi_simple_poll_quit.
AvahiSimplePoll *simple_poll;
// The client which browse_callback uses to create service resolvers.
AvahiClient *client;
// The list to which resolve_callback adds the servers it discovers.
vector<compile_server_info> *servers;
};
// Return the value of the first entry of STRLST with the given KEY, or an
// empty string if there is no such entry, it has no value, or it cannot be
// extracted.
static string
get_value_from_avahi_string_list (AvahiStringList *strlst, const string &key)
{
  AvahiStringList *p = avahi_string_list_find (strlst, key.c_str ());
  if (p == NULL)
    {
      // Key not found.
      return "";
    }

  char *k = NULL, *v = NULL;
  int rc = avahi_string_list_get_pair(p, &k, &v, NULL);
  if (rc < 0)
    {
      // On failure no key/value pair has been handed over, so there is
      // nothing to free here. Freeing k at this point risks releasing memory
      // which the library has already released itself.
      return "";
    }
  if (v == NULL)
    {
      // The entry consists of a key only.
      avahi_free (k);
      return "";
    }

  string value = v;
  avahi_free (k);
  avahi_free (v);
  return value;
}
// Replace the contents of VALUE_VECTOR with the values of all entries of
// STRLST which have the given KEY, in list order. Collection stops at the
// first matching entry which has no value or cannot be extracted.
static void
get_values_from_avahi_string_list (AvahiStringList *strlst, const string &key,
                                   vector<string> &value_vector)
{
  value_vector.clear();

  AvahiStringList *p = avahi_string_list_find (strlst, key.c_str ());
  while (p != NULL)
    {
      char *k = NULL, *v = NULL;
      int rc = avahi_string_list_get_pair(p, &k, &v, NULL);
      if (rc < 0)
        {
          // On failure no key/value pair has been handed over, so there is
          // nothing to free here.
          break;
        }
      if (v == NULL)
        {
          // The entry consists of a key only.
          avahi_free (k);
          break;
        }

      value_vector.push_back(v);
      avahi_free (k);
      avahi_free (v);

      // Search again for the key in the remainder of the list. Merely
      // stepping to the next entry would also pick up the values of entries
      // with other keys.
      p = avahi_string_list_find (avahi_string_list_get_next(p), key.c_str ());
    }
}
// Avahi service resolver callback. On success, adds the resolved server
// (address, port, host name and the sysinfo, certinfo, version and mok_info
// text records) to the server list of the browsing context passed as
// userdata. On failure, reports the problem on clog. The resolver is freed
// in every case.
extern "C"
void resolve_callback(
    AvahiServiceResolver *r,
    AvahiIfIndex interface,
    AvahiProtocol protocol,
    AvahiResolverEvent event,
    const char *name,
    const char *type,
    const char *domain,
    const char *host_name,
    const AvahiAddress *address,
    uint16_t port,
    AvahiStringList *txt,
    AvahiLookupResultFlags ,
    AVAHI_GCC_UNUSED void* userdata)
{
  assert(r);

  const browsing_context *ctx = (browsing_context *)userdata;
  vector<compile_server_info> *discovered = ctx->servers;

  switch (event) {
    case AVAHI_RESOLVER_FOUND: {
      compile_server_info info;

      // Convert the address by way of its text form.
      char text[AVAHI_ADDRESS_STR_MAX];
      avahi_address_snprint(text, sizeof(text), address);
      if (PR_StringToNetAddr (text, & info.address) != PR_SUCCESS) {
        clog << _F("Invalid address '%s' from avahi", text) << endl;
        break;
      }

      // Only IPv4 and IPv6 are supported.
      const bool is_ipv6 = (protocol == AVAHI_PROTO_INET6);
      if (! is_ipv6 && protocol != AVAHI_PROTO_INET)
        break;
      if (is_ipv6) {
        info.address.ipv6.family = PR_AF_INET6;
        info.address.ipv6.scope_id = interface;
      }
      else
        info.address.inet.family = PR_AF_INET;
      info.port = port;

      info.host_name = host_name;

      // The text records; a server announcing no version is taken to be 1.0.
      info.sysinfo = get_value_from_avahi_string_list (txt, "sysinfo");
      info.certinfo = get_value_from_avahi_string_list (txt, "certinfo");
      info.version = get_value_from_avahi_string_list (txt, "version");
      if (info.version.empty ())
        info.version = "1.0";
      get_values_from_avahi_string_list(txt, "mok_info",
                                        info.mok_fingerprints);

      add_server_info (info, *discovered);
      break;
    }

    case AVAHI_RESOLVER_FAILURE:
      clog << _F("Failed to resolve service '%s' of type '%s' in domain '%s': %s",
                 name, type, domain,
                 avahi_strerror(avahi_client_errno(avahi_service_resolver_get_client(r)))) << endl;
      break;

    default:
      break;
  }

  avahi_service_resolver_free(r);
}
// Avahi service browser callback. For each newly found service a resolver
// is created which reports to resolve_callback. A browser failure is
// reported on clog and stops the poll loop. All other events are ignored.
extern "C"
void browse_callback(
    AvahiServiceBrowser *b,
    AvahiIfIndex interface,
    AvahiProtocol protocol,
    AvahiBrowserEvent event,
    const char *name,
    const char *type,
    const char *domain,
    AVAHI_GCC_UNUSED AvahiLookupResultFlags flags,
    void* userdata) {
  browsing_context *ctx = (browsing_context *)userdata;
  AvahiClient *client = ctx->client;
  AvahiSimplePoll *poll = ctx->simple_poll;
  assert(b);

  if (event == AVAHI_BROWSER_FAILURE)
    {
      clog << _F("Avahi browse failed: %s",
                 avahi_strerror(avahi_client_errno(avahi_service_browser_get_client(b))))
           << endl;
      avahi_simple_poll_quit(poll);
    }
  else if (event == AVAHI_BROWSER_NEW)
    {
      // The resolver object is not kept here; resolve_callback frees it.
      AvahiServiceResolver *resolver =
        avahi_service_resolver_new(client, interface, protocol, name, type, domain,
                                   AVAHI_PROTO_UNSPEC, (AvahiLookupFlags)0,
                                   resolve_callback, ctx);
      if (! resolver)
        clog << _F("Failed to resolve service '%s': %s",
                   name, avahi_strerror(avahi_client_errno(client))) << endl;
    }
  // AVAHI_BROWSER_REMOVE, AVAHI_BROWSER_ALL_FOR_NOW and
  // AVAHI_BROWSER_CACHE_EXHAUSTED need no action.
}
// Avahi client callback: a client failure is reported on clog and stops the
// poll loop of the browsing context passed as userdata.
extern "C"
void client_callback(AvahiClient *c, AvahiClientState state, AVAHI_GCC_UNUSED void * userdata) {
  assert(c);
  browsing_context *ctx = (browsing_context *)userdata;
  AvahiSimplePoll *poll = ctx->simple_poll;

  if (state != AVAHI_CLIENT_FAILURE)
    return;

  clog << _F("Avahi Server connection failure: %s", avahi_strerror(avahi_client_errno(c))) << endl;
  avahi_simple_poll_quit(poll);
}
// Avahi timeout callback: stops the poll loop of the browsing context passed
// as userdata.
extern "C"
void timeout_callback(AVAHI_GCC_UNUSED AvahiTimeout *e, AVAHI_GCC_UNUSED void *userdata) {
  browsing_context *ctx = (browsing_context *)userdata;
  avahi_simple_poll_quit(ctx->simple_poll);
}
#endif
// With KEEP set, reduce SERVERS to what it has in common with the servers
// currently online; otherwise add the online servers to SERVERS. The online
// servers are discovered once, by browsing for "_stap._tcp" services with
// Avahi, and kept in the session's server cache. Without Avahi no server
// can be detected and a message is issued at verbosity >= 2.
static void
get_or_keep_online_server_info (
  systemtap_session &s,
  vector<compile_server_info> &servers,
  bool keep
)
{
  // Filtering an empty list cannot produce anything.
  if (keep && servers.empty ())
    return;

  vector<compile_server_info>& online_servers = cscache(s)->online_servers;
  if (online_servers.empty ())
    {
      // An empty entry records that the search has been performed, even if
      // it yields nothing.
      online_servers.push_back (compile_server_info ());
#if HAVE_AVAHI
      // These are declared ahead of the jumps to fail: below.
      vector<compile_server_info> avahi_servers;
      AvahiClient *client = NULL;
      AvahiServiceBrowser *sb = NULL;

      // The poll loop which drives the client and the browser.
      AvahiSimplePoll *simple_poll;
      if (!(simple_poll = avahi_simple_poll_new()))
        {
          clog << _("Failed to create Avahi simple poll object") << endl;
          goto fail;
        }
      browsing_context context;
      context.simple_poll = simple_poll;
      context.servers = & avahi_servers;

      int error;
      client = avahi_client_new (avahi_simple_poll_get (simple_poll),
                                 (AvahiClientFlags)0,
                                 client_callback, & context, & error);
      if (! client)
        {
          clog << _F("Failed to create Avahi client: %s",
                     avahi_strerror(error)) << endl;
          goto fail;
        }
      context.client = client;

      // Browse for systemtap compile servers.
      if (!(sb = avahi_service_browser_new (client, AVAHI_IF_UNSPEC,
                                            AVAHI_PROTO_UNSPEC, "_stap._tcp",
                                            NULL, (AvahiLookupFlags)0,
                                            browse_callback, & context)))
        {
          clog << _F("Failed to create Avahi service browser: %s",
                     avahi_strerror(avahi_client_errno(client))) << endl;
          goto fail;
        }

      // Schedule timeout_callback, which stops the poll loop, 1000/2
      // milliseconds from now.
      struct timeval tv;
      avahi_simple_poll_get(simple_poll)->timeout_new(
        avahi_simple_poll_get(simple_poll),
        avahi_elapse_time(&tv, 1000/2, 0),
        timeout_callback,
        & context);

      // Run the loop until one of the callbacks stops it.
      avahi_simple_poll_loop(simple_poll);

      if (s.verbose >= 6)
        {
          clog << _("Avahi reports the following servers online:") << endl;
          clog << avahi_servers;
        }

      add_server_info (avahi_servers, online_servers);

    fail:
      // NOTE(review): the service browser is not freed separately;
      // presumably avahi_client_free releases it as well -- confirm.
      if (client) {
        avahi_client_free(client);
      }
      if (simple_poll)
        avahi_simple_poll_free(simple_poll);
#else
      // The condition must be on its own line: text following #else on the
      // same line is not compiled, which would make the message
      // unconditional.
      if (s.verbose >= 2)
        clog << _("Unable to detect online servers without avahi") << endl;
#endif

      if (s.verbose >= 5)
        {
          clog << _("All online servers:") << endl;
          clog << online_servers;
        }
    }

  if (keep)
    {
      keep_common_server_info (online_servers, servers);
    }
  else
    {
      add_server_info (online_servers, servers);
    }
}
static void
add_server_info (
const compile_server_info &info, vector<compile_server_info>& target
)
{
if (info.empty ())
return;
bool found = false;
for (vector<compile_server_info>::iterator i = target.begin ();
i != target.end ();
++i)
{
if (info == *i)
{
merge_server_info (info, *i);
found = true;
}
}
if (! found)
target.push_back (info);
}
static void
add_server_info (
const vector<compile_server_info> &source,
vector<compile_server_info> &target
)
{
for (vector<compile_server_info>::const_iterator i = source.begin ();
i != source.end ();
++i)
{
add_server_info (*i, target);
}
}
static void
keep_common_server_info (
const compile_server_info &info_to_keep,
vector<compile_server_info> &filtered_info
)
{
assert (! info_to_keep.empty ());
for (unsigned i = 0; i < filtered_info.size (); )
{
if (filtered_info[i].empty ())
{
++i;
continue;
}
if (info_to_keep == filtered_info[i])
{
merge_server_info (info_to_keep, filtered_info[i]);
++i;
continue;
}
filtered_info.erase (filtered_info.begin () + i);
continue;
}
}
static void
keep_common_server_info (
const vector<compile_server_info> &info_to_keep,
vector<compile_server_info> &filtered_info
)
{
for (unsigned i = 0; i < filtered_info.size (); )
{
if (filtered_info[i].empty ())
{
++i;
continue;
}
bool found = false;
for (unsigned j = 0; j < info_to_keep.size (); ++j)
{
if (filtered_info[i] == info_to_keep[j])
{
merge_server_info (info_to_keep[j], filtered_info[i]);
found = true;
}
}
if (found)
++i;
else
filtered_info.erase (filtered_info.begin () + i);
}
}
// Merge the fields of SOURCE into TARGET.
//
//  - host_name and address: taken from SOURCE whenever SOURCE has them. If
//    both sides already have an address, the two must be equal (asserted).
//  - port and fully_specified: taken from SOURCE only while TARGET's port is
//    still 0.
//  - sysinfo, version, certinfo: each taken from SOURCE only while TARGET's
//    own value is empty.
static void
merge_server_info (
  const compile_server_info &source,
  compile_server_info &target
)
{
  // Host name: the source's name wins when it has one.
  if (! source.host_name.empty())
    {
      target.host_name = source.host_name;
    }

  // Address: two known addresses may never disagree.
  assert (! target.hasAddress () || ! source.hasAddress () || source.address == target.address);
  if (source.hasAddress ())
    {
      copyNetAddr (target.address, source.address);
    }

  // Port, together with the flag that travels with it.
  if (target.port == 0)
    {
      target.port = source.port;
      target.fully_specified = source.fully_specified;
    }

  // Remaining descriptive fields only fill gaps in the target.
  if (target.sysinfo.empty ())
    {
      target.sysinfo = source.sysinfo;
    }
  if (target.version.empty ())
    {
      target.version = source.version;
    }
  if (target.certinfo.empty ())
    {
      target.certinfo = source.certinfo;
    }
}
#if 0
#endif
#endif