HEX
Server: LiteSpeed
System: Linux CentOS-79-64-minimal 3.10.0-1160.119.1.el7.x86_64 #1 SMP Tue Jun 4 14:43:51 UTC 2024 x86_64
User: vishn3436 (5293)
PHP: 8.0.15
Disabled: NONE
Upload Files
File: //scripts/script-server/src/tests/test_utils.py
import json
import os
import shutil
import stat
import threading
import uuid
from copy import copy
from unittest.case import TestCase
from unittest.mock import MagicMock

import utils.file_utils as file_utils
import utils.os_utils as os_utils
from auth.auth_base import Authenticator
from execution.process_base import ProcessWrapper
from model.script_config import ConfigModel, ParameterModel
from react.properties import ObservableDict
from utils import audit_utils

# Root directory (relative to the current working directory) under which the
# helpers below create files; setup() and cleanup() delete it.
temp_folder = 'tests_temp'
# Environment variables overridden through set_env_value(), mapped to the value
# cleanup() writes back (None means cleanup() deletes the variable).
_original_env = {}


def create_file(filepath, *, overwrite=False, text='test text'):
    """Write *text* to ``temp_folder/filepath`` and return the written path.

    ``temp_folder`` and any intermediate directories are created as needed.

    :param filepath: path of the file, relative to ``temp_folder``
    :param overwrite: when False, refuse to replace an already existing file
    :param text: content written to the file
    :return: the path of the written file (prefixed with ``temp_folder``)
    :raises Exception: if the file already exists and ``overwrite`` is False
    """
    # exist_ok=True avoids the exists()-then-makedirs() race and is a no-op
    # when the directory is already there.
    os.makedirs(temp_folder, exist_ok=True)

    folder = os.path.join(temp_folder, os.path.dirname(filepath))
    os.makedirs(folder, exist_ok=True)

    file_path = os.path.join(folder, os.path.basename(filepath))
    if os.path.exists(file_path) and not overwrite:
        raise Exception('File ' + file_path + ' already exists')

    file_utils.write_file(file_path, text)

    return file_path


def create_files(names, dir=None):
    """Create one test file (via ``create_file``) for each entry of *names*.

    :param names: file names to create
    :param dir: optional sub-directory (relative to ``temp_folder``) that
        every name is placed into
    """
    for file_name in names:
        relative_path = file_name if dir is None else os.path.join(dir, file_name)
        create_file(relative_path)


def create_dir(dir_path):
    """Ensure ``temp_folder/dir_path`` exists as a directory and return its path.

    :param dir_path: directory path relative to ``temp_folder``; nested paths
        are created recursively
    :return: the directory path prefixed with ``temp_folder``
    """
    # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
    os.makedirs(temp_folder, exist_ok=True)

    full_path = os.path.join(temp_folder, dir_path)
    os.makedirs(full_path, exist_ok=True)

    return full_path


def setup():
    """Start a test from an empty ``temp_folder``, deleting leftovers first."""
    leftover_exists = os.path.exists(temp_folder)
    if leftover_exists:
        _rmtree(temp_folder)

    os.makedirs(temp_folder)


def cleanup():
    """Undo the global state changed by the test helpers.

    Deletes ``temp_folder``, calls ``os_utils.reset_os()`` and writes back the
    environment values recorded in ``_original_env`` (a None entry means the
    variable is deleted). The OS and environment restoration runs even if
    deleting the folder fails, so one failing cleanup cannot leak state into
    later tests. The deletion error is still propagated afterwards.
    """
    try:
        if os.path.exists(temp_folder):
            _rmtree(temp_folder)
    finally:
        os_utils.reset_os()

        for key, value in _original_env.items():
            if value is None:
                # The test may already have removed the variable itself, so
                # tolerate its absence instead of raising KeyError.
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

        _original_env.clear()


def _rmtree(folder):
    """Recursively delete *folder*, forcing removal of protected entries.

    When ``shutil.rmtree`` fails on an entry, the entry gets owner
    read/write/execute permission and its removal is retried: directories
    with ``os.rmdir``, everything else with ``os.remove``. Deletion continues
    past failures. Each retry failure is printed, and the first one is
    re-raised once the walk is over.

    :param folder: the directory tree to delete
    :raises Exception: the first error raised while retrying a removal
    """
    first_error = None

    def on_rm_error(func, path, exc_info):
        nonlocal first_error
        try:
            os.chmod(path, stat.S_IWRITE | stat.S_IEXEC | stat.S_IREAD)
            # os.remove() can never delete a directory, so a directory whose
            # rmdir failed (e.g. read-only on Windows) must be retried with
            # os.rmdir. Symlinks to directories are removed as links.
            if os.path.isdir(path) and not os.path.islink(path):
                os.rmdir(path)
            else:
                os.remove(path)
        except Exception as e:
            print('Failed to remove path ' + path + ': ' + str(e))
            if first_error is None:
                first_error = e

    shutil.rmtree(folder, onerror=on_rm_error)
    if first_error is not None:
        raise first_error


def set_linux():
    """Delegate to ``os_utils.set_linux()``.

    Presumably makes os_utils behave as if running on Linux (TODO confirm in
    os_utils). cleanup() undoes it through ``os_utils.reset_os()``.
    """
    os_utils.set_linux()

def set_win():
    """Delegate to ``os_utils.set_win()``.

    Presumably makes os_utils behave as if running on Windows (TODO confirm in
    os_utils). cleanup() undoes it through ``os_utils.reset_os()``.
    """
    os_utils.set_win()

def mock_object():
    """Return a bare object that accepts arbitrary attribute assignment.

    Every call builds a brand-new anonymous (empty-named) class, so each mock
    has its own distinct type.
    """
    anonymous_type = type('', (), {})
    return anonymous_type()


def write_script_config(conf_object, filename, config_folder=None):
    """Serialize *conf_object* as JSON into ``<config_folder>/<filename>.json``.

    :param conf_object: JSON-serializable script configuration
    :param filename: file name without the ``.json`` extension
    :param config_folder: target directory; defaults to ``temp_folder/runners``
    :return: the path of the written file
    """
    target_folder = (config_folder
                     if config_folder is not None
                     else os.path.join(temp_folder, 'runners'))
    file_path = os.path.join(target_folder, filename + '.json')

    file_utils.write_file(file_path, json.dumps(conf_object))
    return file_path


def create_script_param_config(
        param_name,
        *,
        type=None,
        default=None,
        required=None,
        secure=None,
        param=None,
        env_var=None,
        no_value=None,
        constant=None,
        multiselect_separator=None,
        multiselect_argument_type=None,
        min=None,
        max=None,
        allowed_values=None,
        values_script=None,
        file_dir=None,
        file_recursive=None,
        file_type=None,
        file_extensions=None,
        excluded_files=None,
        same_arg_param=None,
        values_script_shell=None):
    """Build a raw parameter config dict, as it would appear in a script JSON.

    ``name`` is always present. Every other key is added only when its
    argument is not None. ``values_script`` and ``allowed_values`` both target
    the ``values`` key. When both are given, ``allowed_values`` (converted to a
    list) overwrites the script definition.
    """
    conf = {'name': param_name}

    if type is not None:
        conf['type'] = type

    if values_script is not None:
        script_values = {'script': values_script}
        if values_script_shell is not None:
            script_values['shell'] = values_script_shell
        conf['values'] = script_values

    # (config key, argument value) pairs, listed in key-insertion order.
    optional_entries = [
        ('default', default),
        ('required', required),
        ('secure', secure),
        ('param', param),
        ('env_var', env_var),
        ('no_value', no_value),
        ('constant', constant),
        ('separator', multiselect_separator),
        ('multiselect_argument_type', multiselect_argument_type),
        ('min', min),
        ('max', max),
        ('values', None if allowed_values is None else list(allowed_values)),
        ('file_dir', file_dir),
        ('file_recursive', file_recursive),
        ('file_extensions', file_extensions),
        ('file_type', file_type),
        ('excluded_files', excluded_files),
        ('same_arg_param', same_arg_param),
    ]
    for key, value in optional_entries:
        if value is not None:
            conf[key] = value

    return conf


def create_config_model(name, *,
                        config=None,
                        username='user1',
                        audit_name='127.0.0.1',
                        path=None,
                        parameters=None,
                        parameter_values=None,
                        script_command='ls',
                        output_files=None,
                        requires_terminal=None,
                        schedulable=True):
    """Build a ``ConfigModel`` for a test script.

    :param name: script name; also used as ``path`` when *path* is None
    :param config: base config dict; the explicit arguments below override it
    :param username: user the model is created for
    :param audit_name: audit name the model is created for
    :param path: config path passed to ``ConfigModel``
    :param parameters: list of raw parameter configs
    :param parameter_values: if given, applied to the model via
        ``set_all_param_values``
    :param script_command: stored as ``script_path``
    :param output_files: stored as ``output_files`` when not None
    :param requires_terminal: stored as ``requires_terminal`` when not None
    :param schedulable: stored as ``scheduling.enabled`` when not None
    :return: the created ``ConfigModel``
    """
    result_config = {}

    if config:
        result_config.update(config)

    result_config['name'] = name

    if parameters is not None:
        result_config['parameters'] = parameters

    if path is None:
        path = name

    if output_files is not None:
        result_config['output_files'] = output_files

    if requires_terminal is not None:
        result_config['requires_terminal'] = requires_terminal

    if schedulable is not None:
        result_config['scheduling'] = {'enabled': schedulable}

    result_config['script_path'] = script_command

    model = ConfigModel(result_config, path, username, audit_name)
    if parameter_values is not None:
        # Previously the model itself was passed here, so the requested values
        # were never applied.
        model.set_all_param_values(parameter_values)

    return model


def create_parameter_model(name=None,
                           *,
                           type=None,
                           values_script=None,
                           default=None,
                           required=None,
                           secure=None,
                           param=None,
                           env_var=None,
                           no_value=None,
                           constant=None,
                           multiselect_separator=None,
                           multiselect_argument_type=None,
                           min=None,
                           max=None,
                           allowed_values=None,
                           username='user1',
                           audit_name='127.0.0.1',
                           all_parameters=None,
                           file_dir=None,
                           file_recursive=None,
                           other_param_values: ObservableDict = None,
                           values_script_shell=None,
                           file_type=None,
                           file_extensions=None,
                           excluded_files=None,
                           same_arg_param=None):
    """Build a ``ParameterModel`` from keyword arguments.

    The raw config is produced by ``create_script_param_config``, so every
    config-related argument has the same meaning there. ``file_type``,
    ``file_extensions``, ``excluded_files`` and ``same_arg_param`` are
    forwarded as well, so tests can set every option that helper supports.

    :param all_parameters: other parameter models, exposed to the model
        through a zero-argument supplier callable (defaults to an empty list)
    :param other_param_values: values of other parameters, passed through
    :return: the created ``ParameterModel``
    """
    config = create_script_param_config(
        name,
        type=type,
        values_script=values_script,
        default=default,
        required=required,
        secure=secure,
        param=param,
        env_var=env_var,
        no_value=no_value,
        constant=constant,
        multiselect_separator=multiselect_separator,
        multiselect_argument_type=multiselect_argument_type,
        min=min,
        max=max,
        allowed_values=allowed_values,
        file_dir=file_dir,
        file_recursive=file_recursive,
        file_type=file_type,
        file_extensions=file_extensions,
        excluded_files=excluded_files,
        same_arg_param=same_arg_param,
        values_script_shell=values_script_shell)

    if all_parameters is None:
        all_parameters = []

    return ParameterModel(config,
                          username,
                          audit_name,
                          lambda: all_parameters,
                          other_param_values=other_param_values)


def create_simple_parameter_configs(names):
    """Map each parameter name to a minimal ``{'name': <name>}`` config."""
    configs = {}
    for param_name in names:
        configs[param_name] = {'name': param_name}
    return configs


def create_parameter_model_from_config(config,
                                       *,
                                       username='user1',
                                       audit_name='127.0.0.1',
                                       working_dir=None,
                                       all_parameters=None):
    """Build a ``ParameterModel`` directly from a raw parameter config dict.

    :param config: raw parameter config; None is treated as an empty dict
    :param working_dir: passed through to ``ParameterModel``
    :param all_parameters: other parameter models (defaults to an empty list).
        They are passed through a zero-argument supplier callable, the same
        way ``create_parameter_model`` passes them to ``ParameterModel``.
        Previously the bare list was passed in that position.
    :return: the created ``ParameterModel``
    """
    if all_parameters is None:
        all_parameters = []

    if config is None:
        config = {}

    return ParameterModel(config,
                          username,
                          audit_name,
                          lambda: all_parameters,
                          working_dir=working_dir)


def create_audit_names(ip=None, auth_username=None, proxy_username=None, hostname=None):
    """Build an audit-names dict keyed by the ``audit_utils`` constants.

    Arguments that are None are omitted from the result.
    """
    candidates = (
        (audit_utils.IP, ip),
        (audit_utils.AUTH_USERNAME, auth_username),
        (audit_utils.PROXIED_USERNAME, proxy_username),
        (audit_utils.HOSTNAME, hostname),
    )
    return {key: value for key, value in candidates if value is not None}


def set_env_value(key, value):
    """Set environment variable *key* to *value* for the duration of a test.

    On the first override of a key, its pre-test value is remembered in
    ``_original_env`` (None if the variable was unset), so ``cleanup()`` can
    restore it. Later overrides of the same key keep that first value.
    """
    if key not in _original_env:
        # Record the value in effect BEFORE the override. Previously the new
        # value was stored, so cleanup() "restored" the overridden value.
        _original_env[key] = os.environ.get(key)

    os.environ[key] = value


def assert_large_dict_equal(expected, actual, testcase):
    """Assert two dicts are equal, reporting only differing keys for big dicts.

    If both dicts have fewer than 20 entries, they are compared directly.
    Otherwise only the mismatching keys (at most 50) are passed to
    ``testcase.assertEqual``, which keeps failure messages readable. A key
    absent on one side is shown as ``<missing>``, so it is never confused with
    a key that is present and mapped to None.

    :param expected: the expected dict
    :param actual: the dict under test
    :param testcase: object providing ``assertEqual`` (e.g. a ``TestCase``)
    """
    if len(expected) < 20 and len(actual) < 20:
        testcase.assertEqual(expected, actual)
        return

    if expected == actual:
        return

    class _Missing:
        # Placeholder for "key not present". Unlike None it equals no real
        # value, so {'k': None} vs {} is still reported as a difference.
        def __repr__(self):
            return '<missing>'

    missing = _Missing()
    max_reported = 50

    diff_expected = {}
    diff_actual = {}
    too_large_diff = False

    all_keys = set(expected.keys()) | set(actual.keys())
    for key in all_keys:
        expected_value = expected.get(key, missing)
        actual_value = actual.get(key, missing)

        if expected_value == actual_value:
            continue

        diff_expected[key] = expected_value
        diff_actual[key] = actual_value

        if len(diff_expected) >= max_reported:
            too_large_diff = True
            break

    message = 'Showing only different elements'
    if too_large_diff:
        message += ' (limited to 50)'

    testcase.assertEqual(diff_expected, diff_actual, message)


def wait_observable_close_notification(observable, timeout):
    """Block until *observable* signals close, or until *timeout* seconds pass.

    Returns None in both cases. Callers that need to tell a timeout apart
    must inspect the observable themselves.
    """
    closed = threading.Event()
    observable.subscribe_on_close(closed.set)
    closed.wait(timeout)


def mock_request_handler(*, arguments: dict = None, method='GET', headers=None):
    """Create a minimal stand-in for a web request handler.

    The returned object exposes ``get_argument(name, default=None)``, backed
    by *arguments*, plus ``request.method`` and ``request.headers``.

    :param arguments: values returned by ``get_argument``; None means no
        arguments at all
    :param method: value of ``request.method``
    :param headers: value of ``request.headers`` (a fresh dict if None)
    :return: the mock handler
    """
    if headers is None:
        headers = {}

    request_handler = mock_object()

    def get_argument(arg_name, default=None):
        # Accepts an optional default, presumably mirroring the
        # (name, default) shape of tornado's RequestHandler.get_argument, so
        # callers passing a default don't hit a TypeError. A missing argument
        # yields ``default`` (None unless given).
        if arguments is None:
            return default
        return arguments.get(arg_name, default)

    request_handler.get_argument = get_argument

    request_handler.request = mock_object()
    request_handler.request.method = method
    request_handler.request.headers = headers

    return request_handler


def assert_dir_files(expected_files, dir_path, test_case: TestCase):
    """Assert that *dir_path* contains exactly *expected_files*, in any order.

    :param expected_files: any iterable of entry names (list, set, generator)
    :param dir_path: directory whose entries (``os.listdir``) are checked
    :param test_case: the ``TestCase`` used to report the comparison
    """
    # sorted() already returns a new list, so no defensive copy is needed.
    # Dropping copy() also lets callers pass one-shot iterators.
    expected_sorted = sorted(expected_files)
    actual_sorted = sorted(os.listdir(dir_path))

    test_case.assertSequenceEqual(expected_sorted, actual_sorted)


class _MockProcessWrapper(ProcessWrapper):
    """Process wrapper that runs nothing; tests drive it explicitly.

    Tests push output with ``write_output`` and end the "process" with
    ``finish``. ``stop``/``kill`` end it with fixed exit codes.
    """

    def __init__(self, executor, command, working_directory, env_variables):
        # ``executor`` is accepted for constructor compatibility but unused.
        super().__init__(command, working_directory, env_variables)

        self.exit_code = None
        self.finished = False
        # uuid1-derived integer, so each mock gets a distinct id.
        self.process_id = int.from_bytes(uuid.uuid1().bytes, byteorder='big')
        self.finish_condition = threading.Condition()

    def get_process_id(self):
        return self.process_id

    # method for tests
    def finish(self, exit_code):
        """End the mocked process with *exit_code*; raise if already finished."""
        if self.is_finished():
            raise Exception('Cannot finish a script twice')
        self.__finish(exit_code)

    # method for tests
    def write_output(self, output):
        """Feed *output* into the wrapper as if the script had printed it."""
        self._write_script_output(output)

    # NOTE(review): 9/15 look swapped relative to SIGKILL(9)/SIGTERM(15); kept
    # as-is because other tests may assert on these exit codes.
    def stop(self):
        self.__finish(9)

    def kill(self):
        self.__finish(15)

    def __finish(self, exit_code):
        """Mark the process finished (first call wins) and wait for notifiers."""
        with self.finish_condition:
            # Checked under the lock: concurrent stop()/kill()/finish() calls
            # previously could all pass an unlocked check and close the
            # stream twice.
            if self.finished:
                return

            self.exit_code = exit_code
            self.finished = True
            self.output_stream.close()
            self.finish_condition.notify_all()

        # Joined outside the lock: the notifier thread (presumably started by
        # ProcessWrapper) may itself need the condition, e.g. via wait_finish.
        self.notify_finish_thread.join()

    def is_finished(self):
        return self.finished

    def get_return_code(self):
        return self.exit_code

    def pipe_process_output(self):
        pass

    def start_execution(self, command, working_directory):
        pass

    def wait_finish(self):
        """Block until the mocked process has been finished."""
        with self.finish_condition:
            while not self.finished:
                self.finish_condition.wait(0.01)

    def write_to_input(self, value):
        pass


class AnyUserAuthorizer:
    """Authorizer stub for tests: every check returns True for any user."""

    def is_allowed_in_app(self, user_id):
        return True

    def is_allowed(self, user_id, allowed_users):
        # allowed_users is ignored; access is always granted.
        return True

    def is_admin(self, user_id):
        return True


class _IdGeneratorMock:
    """Deterministic id generator for tests.

    Hands out '123', '124', ... as strings and records every id it produced
    in ``generated_ids``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._next_id = 123
        self.generated_ids = []

    def next_id(self):
        new_id = str(self._next_id)
        self.generated_ids.append(new_id)
        self._next_id += 1
        return new_id


class AsyncMock(MagicMock):
    """MagicMock whose calls must be awaited.

    Calling the mock returns a coroutine. Awaiting it performs the regular
    MagicMock call (recording it and applying ``return_value`` /
    ``side_effect``) and yields that call's result.
    """

    async def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


class MockAuthenticator(Authenticator):
    """Authenticator for tests: the authentication result is the client's remote IP."""

    def authenticate(self, request_handler):
        request = request_handler.request
        return request.remote_ip