diff --git a/docker/netrave-protohandler/Gemfile b/docker/netrave-protohandler/Gemfile index 75b89af..14edeeb 100644 --- a/docker/netrave-protohandler/Gemfile +++ b/docker/netrave-protohandler/Gemfile @@ -3,4 +3,6 @@ source 'https://rubygems.org' gem 'async' +gem 'openssl' gem 'sequel' +gem 'socket' diff --git a/docker/netrave-protohandler/netrave_protohandler.rb b/docker/netrave-protohandler/netrave_protohandler.rb index 82b217e..7ee1eb1 100644 --- a/docker/netrave-protohandler/netrave_protohandler.rb +++ b/docker/netrave-protohandler/netrave_protohandler.rb @@ -4,31 +4,51 @@ require 'socket' require 'async' require 'sequel' require 'openssl' +require 'securerandom' # Set up the database DB = Sequel.sqlite # In-memory database + +# Create processors table +DB.create_table :processors do + primary_key :proto_handler_id + String :uuid + String :domain + Integer :port +end + +# Create orchestrators table +DB.create_table :orchestrators do + primary_key :id + String :domain + Integer :port +end + +# Create blacklist table +DB.create_table :blacklist do + primary_key :id + String :uuid +end + processors = DB[:processors] +orchestrators = DB[:orchestrators] +blacklist = DB[:blacklist] listen_ip = ENV['LISTEN_IP'] || '0.0.0.0' listen_port = ENV['LISTEN_PORT'] || 3080 server = TCPServer.new(listen_ip, listen_port) -# This hash will store the connections to the consumers -connections = {} +# This hash will store the recently connected UUIDs with their last validation timestamp +recently_connected = {} def create_socket(ip, port) # rubocop:disable Metrics/MethodLength - # If the IP address is set to "loopback", replace it with the actual loopback IP address ip = '127.0.0.1' if ip.downcase == 'loopback' - if ip =~ Resolv::IPv4::Regex - # If the IP address is an IPv4 address, create an unencrypted socket TCPSocket.new(ip, port) else - # If the IP address is a domain name, create an SSL socket ssl_context = OpenSSL::SSL::SSLContext.new ssl_context.verify_mode = OpenSSL::SSL::VERIFY_PEER - tcp_socket = TCPSocket.new(ip, port) ssl_socket = OpenSSL::SSL::SSLSocket.new(tcp_socket, ssl_context) ssl_socket.sync_close = true @@ -37,43 +57,41 @@ def create_socket(ip, port) # rubocop:disable Metrics/MethodLength end end -Async do - loop do - Async::Task.new do - client = server.accept - - begin - while (line = client.gets) - # Here we handle each line of input from the client - handle_input(line, connections) - end - ensure - # This code will be executed when the fiber is finished, regardless of whether an exception was raised - client.close - Async::Task.current.stop # Stop the current task - end - end - end -end - -def handle_input(line, connections) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength - # Split the line into command and args +def handle_input(line, connections, consumer_uuids, blacklisted_ips, client, processors, blacklist, recently_connected) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity, Metrics/ParameterLists command, *args = line.split + uuid = args.shift if command != 'REGISTER' + + # Check if the UUID is blacklisted + if blacklisted_ips.include?(client.peeraddr[3]) || blacklist.where(uuid:).count.positive? + client.puts 'TERMINATE' + return + end + + # Check if the UUID is in the recently connected cache + if recently_connected[uuid] && Time.now - recently_connected[uuid] < 120 + # UUID is recently connected and within 2 minutes, no need to re-validate + elsif command != 'REGISTER' && (uuid.nil? || !consumer_uuids.values.include?(uuid)) + # UUID is not recently connected or is invalid, add to blacklist + blacklisted_ips.add(client.peeraddr[3]) + blacklist.insert(uuid:) + client.puts 'ERROR Unrecognized UUID' + return + else + # UUID is valid, update the recently connected cache + recently_connected[uuid] = Time.now + end case command - when 'NEW_PROCESSOR' - # Handle new processor registration - id, domain, port = args - - # Create a new connection to the processor + when 'REGISTER' + id = processors.max(:proto_handler_id).to_i + 1 + domain, port = args + uuid = SecureRandom.uuid + client.puts "UUID #{uuid}" + consumer_uuids[id] = uuid Async do processor_connection = create_socket(domain, port) - - # Store the connection in the hash connections[id] = processor_connection - - # Add the processor to the database - processors.insert(consumer_id: id, ip: domain, port: port) + processors.insert(proto_handler_id: id, uuid:, domain:, port:) end when 'REQUEST' @@ -126,3 +144,33 @@ def handle_input(line, connections) # rubocop:disable Metrics/AbcSize, Metrics/M puts "Unknown command: #{command}" end end + +def register_orchestrator(line, orchestrators) + domain, port = line.split + id = orchestrators.max(:id).to_i + 1 + orchestrators.insert(id:, domain:, port:) + puts "Orchestrator registered with domain: #{domain}, port: #{port}" +end + +Async do + loop do + Async::Task.new do + client = server.accept + + begin + line = client.gets + # Determine if the connection is from an orchestrator for registration + if is_orchestrator_registration?(line) # Define this method to identify orchestrator registration + register_orchestrator(line, orchestrators) + else + # Here we handle each line of input from the client + handle_input(line, client, processors, blacklist, recently_connected) # Pass the variables here + end + ensure + # This code will be executed when the fiber is finished, regardless of whether an exception was raised + client.close + Async::Task.current.stop + end + end + end +end