diff --git a/.gitignore b/.gitignore index 796927f5..eb333ed5 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ __pycache__ /pkg/openapiart.go /pkg/httpapi _debug_bin +.env +.env2 \ No newline at end of file diff --git a/openapiart/common.go b/openapiart/common.go index 2a6b99f5..acdce537 100644 --- a/openapiart/common.go +++ b/openapiart/common.go @@ -161,48 +161,46 @@ func validationResult() error { return nil } -func validateMac(mac string) error { +func validateMac(mac string, path string) error { macSlice := strings.Split(mac, ":") if len(macSlice) != 6 { - return fmt.Errorf(fmt.Sprintf("Invalid Mac address %s", mac)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid mac string, instead of `%s`", path, mac)) } - octInd := []string{"0th", "1st", "2nd", "3rd", "4th", "5th"} - for ind, val := range macSlice { + for _, val := range macSlice { num, err := strconv.ParseUint(val, 16, 32) if err != nil || num > 255 { - return fmt.Errorf(fmt.Sprintf("Invalid Mac address at %s octet in %s mac", octInd[ind], mac)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid mac string, instead of `%s`", path, mac)) } } return nil } -func validateIpv4(ip string) error { +func validateIpv4(ip string, path string) error { ipSlice := strings.Split(ip, ".") if len(ipSlice) != 4 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv4 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv4 string, instead of `%s`", path, ip)) } - octInd := []string{"1st", "2nd", "3rd", "4th"} - for ind, val := range ipSlice { + for _, val := range ipSlice { num, err := strconv.ParseUint(val, 10, 32) if err != nil || num > 255 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv4 address at %s octet in %s ipv4", octInd[ind], ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv4 string, instead of `%s`", path, ip)) } } return nil } -func validateIpv6(ip string) error { +func validateIpv6(ip string, path string) error { ip = strings.Trim(ip, " \t") if strings.Count(ip, " ") > 0 || strings.Count(ip, ":") > 7 || strings.Count(ip, "::") > 1 || strings.Count(ip, ":::") > 0 || strings.Count(ip, ":") == 0 { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if (string(ip[0]) == ":" && string(ip[:2]) != "::") || (string(ip[len(ip)-1]) == ":" && string(ip[len(ip)-2:]) != "::") { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if strings.Count(ip, "::") == 0 && strings.Count(ip, ":") != 7 { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if ip == "::" { return nil @@ -217,69 +215,69 @@ func validateIpv6(ip string) error { r := strings.NewReplacer("::", ":0:") ip = r.Replace(ip) } - octInd := []string{"1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"} ipSlice := strings.Split(ip, ":") - for ind, val := range ipSlice { + for _, val := range ipSlice { num, err := strconv.ParseUint(val, 16, 64) if err != nil || num > 65535 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv6 address at %s octet in %s ipv6", octInd[ind], ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } } return nil } -func validateHex(hex string) error { +func validateHex(hex string, path string) error { matched, err := regexp.MatchString(`^[0-9a-fA-F]+$|^0[x|X][0-9a-fA-F]+$`, hex) if err != nil || !matched { - return fmt.Errorf(fmt.Sprintf("Invalid hex value %s", hex)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid hex string, instead of %s", path, hex)) } return nil } -func validateSlice(valSlice []string, sliceType string) error { +func validateSlice(valSlice []string, sliceType string, path string) error { indices := []string{} var err error for i, val := range valSlice { if sliceType == "mac" { - err = validateMac(val) + err = validateMac(val, path) } else if sliceType == "ipv4" { - err = validateIpv4(val) + err = validateIpv4(val, path) } else if sliceType == "ipv6" { - err = validateIpv6(val) + err = validateIpv6(val, path) } else if sliceType == "hex" { - err = validateHex(val) + err = validateHex(val, path) } else { - return fmt.Errorf(fmt.Sprintf("Invalid slice type received <%s>", sliceType)) + return fmt.Errorf(fmt.Sprintf("invalid slice type received <%s>", sliceType)) } if err != nil { - indices = append(indices, fmt.Sprintf("%d", i)) + indices = append(indices, + fmt.Sprintf("value of `%s[%d]` must be a valid %s string, instead of `%s`", path, i, sliceType, val)) } } if len(indices) > 0 { return fmt.Errorf( - fmt.Sprintf("Invalid %s addresses at indices %s", sliceType, strings.Join(indices, ",")), + strings.Join(indices, "\n"), ) } return nil } -func validateMacSlice(mac []string) error { - return validateSlice(mac, "mac") +func validateMacSlice(mac []string, path string) error { + return validateSlice(mac, "mac", path) } -func validateIpv4Slice(ip []string) error { - return validateSlice(ip, "ipv4") +func validateIpv4Slice(ip []string, path string) error { + return validateSlice(ip, "ipv4", path) } -func validateIpv6Slice(ip []string) error { - return validateSlice(ip, "ipv6") +func validateIpv6Slice(ip []string, path string) error { + return validateSlice(ip, "ipv6", path) } -func validateHexSlice(hex []string) error { - return validateSlice(hex, "hex") +func validateHexSlice(hex []string, path string) error { + return validateSlice(hex, "hex", path) } diff --git a/openapiart/common.py b/openapiart/common.py index e40061f1..5b0f9f22 100644 --- a/openapiart/common.py +++ b/openapiart/common.py @@ -1,6 +1,7 @@ import importlib import logging import json +from markupsafe import string import yaml import requests import urllib3 @@ -216,39 +217,62 @@ def _decode(self, dict_object): class OpenApiValidator(object): __slots__ = () + _validation_errors = [] def __init__(self): - pass + pass + + def _append_error(self, msg): + self._validation_errors.append(msg) + + def _get_validation_errors(self): + return self._validation_errors + + def _clear_errors(self): + import platform + if '2.7' in platform.python_version().rsplit(".", 1)[0]: + del self._validation_errors[:] + else: + self._validation_errors.clear() - def validate_mac(self, mac): + def validate_mac(self, path, mac): + msg = "value of `{}` must be a valid mac address, instead of `{}`".format(path, mac) if mac is None or not isinstance(mac, (str, unicode)) or mac.count(" ") != 0: - return False + self._append_error(msg) try: if len(mac) != 17: - return False - return all([0 <= int(oct, 16) <= 255 for oct in mac.split(":")]) + self._append_error(msg) + if all([0 <= int(oct, 16) <= 255 for oct in mac.split(":")]) is False: + self._append_error(msg) except Exception: - return False + self._append_error(msg) - def validate_ipv4(self, ip): + def validate_ipv4(self, path, ip): + msg = "value of `{}` must be a valid ipv4 address, instead of `{}`".format(path, ip) if ip is None or not isinstance(ip, (str, unicode)) or ip.count(" ") != 0: - return False + self._append_error(msg) if len(ip.split(".")) != 4: - return False + self._append_error(msg) try: - return all([0 <= int(oct) <= 255 for oct in ip.split(".", 3)]) + if all([0 <= int(oct) <= 255 for oct in ip.split(".", 3)]) is False: + self._append_error(msg) except Exception: - return False + self._append_error(msg) - def validate_ipv6(self, ip): + def validate_ipv6(self, path, ip): + msg = "value of `{}` must be a valid ipv6 address, instead of `{}`".format(path, ip) if ip is None or not isinstance(ip, (str, unicode)): + self._append_error(msg) return False ip = ip.strip() if ip.count(" ") > 0 or ip.count(":") > 7 or ip.count("::") > 1 or ip.count(":::") > 0: + self._append_error(msg) return False if (ip[0] == ":" and ip[:2] != "::") or (ip[-1] == ":" and ip[-2:] != "::"): + self._append_error(msg) return False if ip.count("::") == 0 and ip.count(":") != 7: + self._append_error(msg) return False if ip == "::": return True @@ -259,68 +283,90 @@ def validate_ipv6(self, ip): else: ip = ip.replace("::", ":0:") try: - return all([True if (0 <= int(oct, 16) <= 65535) and (1 <= len(oct) <= 4) else False for oct in ip.split(":")]) + verdict = all([ + True if (0 <= int(oct, 16) <= 65535) and (1 <= len(oct) <= 4) else False for oct in ip.split(":") + ]) + if verdict is False: + self._append_error(msg) except Exception: - return False + self._append_error(msg) - def validate_hex(self, hex): + def validate_hex(self, path, hex): + msg = "value of `{}` must be a valid hex string, instead of `{}`".format(path, hex) if hex is None or not isinstance(hex, (str, unicode)): - return False + self._append_error(msg) try: int(hex, 16) return True except Exception: - return False + self._append_error(msg) - def validate_integer(self, value, min, max): + def validate_integer(self, path, value): if value is None or not isinstance(value, int): - return False - if value < 0: - return False - if min is not None and value < min: - return False - if max is not None and value > max: - return False - return True + self._append_error("value of `{}` must be a valid int type, instead of `{}`".format( + path, value + )) + + def validate_min_max(self, path, value, min, max): + if isinstance(value, str): + value = len(value) + if (min is not None and value < min) or (max is not None and value > max): + self._append_error("length of `{}` must be in the range of [{}, {}], instead of `{}`".format( + path, + min if min is not None else "", + max if max is not None else "", + value + )) - def validate_float(self, value): - return isinstance(value, (int, float)) + def validate_float(self, path, value): + if isinstance(value, (int, float)) is False: + self._append_error("value of `{}` must be a valid float type, instead of `{}`".format( + path, value + )) - def validate_string(self, value, min_length, max_length): + def validate_string(self, path, value): if value is None or not isinstance(value, (str, unicode)): - return False - if min_length is not None and len(value) < min_length: - return False - if max_length is not None and len(value) > max_length: - return False - return True + self._append_error("value of `{}` must be a valid string type, instead of `{}`".format( + path, value + )) - def validate_bool(self, value): - return isinstance(value, bool) + def validate_bool(self, path, value): + if isinstance(value, bool) is False: + self._append_error("value of `{}` must be a valid bool type, instead of `{}`".format( + path, value + )) - def validate_list(self, value, itemtype, min, max, min_length, max_length): + def validate_list(self, path, value, itemtype, min, max): if value is None or not isinstance(value, list): return False v_obj = getattr(self, "validate_{}".format(itemtype), None) if v_obj is None: raise AttributeError("{} is not a valid attribute".format(itemtype)) - v_obj_lst = [] - for item in value: - if itemtype == "integer": - v_obj_lst.append(v_obj(item, min, max)) - elif itemtype == "string": - v_obj_lst.append(v_obj(item, min_length, max_length)) + for ind, item in enumerate(value): + if itemtype in ["integer", "string", "float"]: + v_obj(path + "[{}]".format(ind), item) + self.validate_min_max(path, item, min, max) else: - v_obj_lst.append(v_obj(item)) - return v_obj_lst + v_obj(path + "[{}]".format(ind), item) - def validate_binary(self, value): - if value is None or not isinstance(value, (str, unicode)): - return False - return all([True if int(bin) == 0 or int(bin) == 1 else False for bin in value]) + def validate_binary(self, path, value): + if value is None or not isinstance(value, (str, unicode)) or \ + all([True if int(bin) == 0 or int(bin) == 1 else False for bin in value]) is False: + self._append_error("value of `{}` must be a valid binary string, instead of `{}`".format( + path, value + )) - def types_validation(self, value, type_, err_msg, itemtype=None, min=None, max=None, min_length=None, max_length=None): - type_map = {int: "integer", str: "string", float: "float", bool: "bool", list: "list", "int64": "integer", "int32": "integer", "double": "float"} + def types_validation(self, value, type_, path, itemtype=None, min=None, max=None): + type_map = { + int: "integer", + str: "string", + float: "float", + bool: "bool", + list: "list", + "int64": "integer", + "int32": "integer", + "double": "float" + } if type_ in type_map: type_ = type_map[type_] if itemtype is not None and itemtype in type_map: @@ -329,36 +375,15 @@ def types_validation(self, value, type_, err_msg, itemtype=None, min=None, max=N if v_obj is None: msg = "{} is not a valid or unsupported format".format(type_) raise TypeError(msg) - if type_ == "list": - verdict = v_obj(value, itemtype, min, max, min_length, max_length) - if all(verdict) is True: - return - err_msg = "{} \n {} are not valid".format(err_msg, [value[index] for index, item in enumerate(verdict) if item is False]) - verdict = False - elif type_ == "integer": - verdict = v_obj(value, min, max) - if verdict is True: - return - min_max = "" - if min is not None: - min_max = ", expected min {}".format(min) - if max is not None: - min_max = min_max + ", expected max {}".format(max) - err_msg = "{} \n got {} of type {} {}".format(err_msg, value, type(value), min_max) - elif type_ == "string": - verdict = v_obj(value, min_length, max_length) - if verdict is True: - return - msg = "" - if min_length is not None: - msg = ", expected min {}".format(min_length) - if max_length is not None: - msg = msg + ", expected max {}".format(max_length) - err_msg = "{} \n got {} of type {} {}".format(err_msg, value, type(value), msg) - else: - verdict = v_obj(value) - if verdict is False: - raise TypeError(err_msg) + v_obj(path, value) if type_ != "list" else v_obj(path, value, itemtype, min, max) + if type_ in ["integer", "string", "float"]: + self.validate_min_max(path, value, min, max) + + def _raise_validation(self): + errors = "\n".join(self._validation_errors) + if len(self._get_validation_errors()) > 0: + self._clear_errors() + raise Exception(errors) class OpenApiObject(OpenApiBase, OpenApiValidator): @@ -371,6 +396,9 @@ class OpenApiObject(OpenApiBase, OpenApiValidator): """ __slots__ = ("_properties", "_parent", "_choice") + + _JSON_NAME = "" + _DEFAULTS = {} _TYPES = {} _REQUIRED = [] @@ -397,16 +425,21 @@ def _has_choice(self, name): return True else: return False + + def _is_enum_valid(self, name, value): + if name in self._TYPES and "enum" in self._TYPES[name]: + if value in self._TYPES[name]["enum"]: + return True + else: + return False + return True def _get_property(self, name, default_value=None, parent=None, choice=None): if name in self._properties and self._properties[name] is not None: return self._properties[name] if isinstance(default_value, type) is True: self._set_choice(name) - if "_choice" in default_value.__slots__: - self._properties[name] = default_value(parent=parent, choice=choice) - else: - self._properties[name] = default_value(parent=parent) + self._properties[name] = default_value(parent=parent) if "_DEFAULTS" in dir(self._properties[name]) and "choice" in self._properties[name]._DEFAULTS: getattr(self._properties[name], self._properties[name]._DEFAULTS["choice"]) else: @@ -422,17 +455,19 @@ def _set_property(self, name, value, choice=None): self._set_choice(name) self._properties[name] = self._DEFAULTS[name] else: - self._set_choice(name) - self._properties[name] = value + if not self._is_enum_valid(name, value): + self._append_error("{} is not a valid enum for property {}".format(value, name)) + else: + self._set_choice(name) + self._properties[name] = value if self._parent is not None and self._choice is not None and value is not None: self._parent._set_property("choice", self._choice) def _encode(self): """Helper method for serialization""" + self._validate(self._JSON_NAME) output = {} - self._validate_required() for key, value in self._properties.items(): - self._validate_types(key, value) if isinstance(value, (OpenApiObject, OpenApiIter)): output[key] = value._encode() elif value is not None: @@ -448,7 +483,7 @@ def _decode(self, obj): if isinstance(property_value, dict): child = self._get_child_class(property_name) if "choice" in child[1]._TYPES and "_parent" in child[1].__slots__: - property_value = child[1](self, property_name)._decode(property_value) + property_value = child[1](self)._decode(property_value) elif "_parent" in child[1].__slots__: property_value = child[1](self)._decode(property_value) else: @@ -467,8 +502,7 @@ def _decode(self, obj): if "format" in self._TYPES[property_name] and self._TYPES[property_name]["format"] == "int64": property_value = int(property_value) self._properties[property_name] = property_value - self._validate_types(property_name, property_value) - self._validate_required() + self._validate(self._JSON_NAME) return self def _get_child_class(self, property_name, is_property_list=False): @@ -499,35 +533,37 @@ def clone(self): """Creates a deep copy of the current object""" return self.__deepcopy__(None) - def _validate_required(self): + def _validate_required(self, path): """Validates the required properties are set Use getattr as it will set any defaults prior to validating """ if getattr(self, "_REQUIRED", None) is None: return for name in self._REQUIRED: - if getattr(self, name, None) is None: - msg = "{} is a mandatory property of {}" " and should not be set to None".format( - name, - self.__class__, + if self._properties.get(name) is None: + msg = "required field `{}.{}` must not be empty".format( + path, name ) - raise ValueError(msg) + self._append_error(msg) - def _validate_types(self, property_name, property_value): + def _validate_types(self, path, property_name, property_value): common_data_types = [list, str, int, float, bool] if property_name not in self._TYPES: - # raise ValueError("Invalid Property {}".format(property_name)) return details = self._TYPES[property_name] - if property_value is None and property_name not in self._DEFAULTS and property_name not in self._REQUIRED: + if property_value is None: return if "enum" in details and property_value not in details["enum"]: - msg = "property {} shall be one of these" " {} enum, but got {} at {}" - raise TypeError(msg.format(property_name, details["enum"], property_value, self.__class__)) + msg = "enum field `{}` must be one of {}, instead of `{}`".format( + path, details["enum"], property_value + ) + self._append_error(msg) if details["type"] in common_data_types and "format" not in details: - msg = "property {} shall be of type {} at {}".format(property_name, details["type"], self.__class__) - self.types_validation(property_value, details["type"], msg, details.get("itemtype"), details.get("minimum"), details.get("maximum"), - details.get("minLength"), details.get("maxLength")) + self.types_validation( + property_value, details["type"], path, details.get("itemtype"), + details.get("minimum", details.get("minLength")), + details.get("maximum", details.get("maxLength")) + ) if details["type"] not in common_data_types: class_name = details["type"] @@ -535,18 +571,35 @@ def _validate_types(self, property_name, property_value): module = importlib.import_module(self.__module__) object_class = getattr(module, class_name) if not isinstance(property_value, object_class): - msg = "property {} shall be of type {}," " but got {} at {}" - raise TypeError(msg.format(property_name, class_name, type(property_value), self.__class__)) + msg = "value of `{}` must be a valid {} type, instead of `{}`" + self._append_error( + msg.format(path, class_name, type(property_value)) + ) if "format" in details: - msg = "Invalid {} format, expected {} at {}".format(property_value, details["format"], self.__class__) _type = details["type"] if details["type"] is list else details["format"] - self.types_validation(property_value, _type, msg, details["format"], details.get("minimum"), details.get("maximum"), - details.get("minLength"), details.get("maxLength")) - - def validate(self): - self._validate_required() + self.types_validation( + property_value, _type, path, details["format"], + details.get("minimum", details.get("minLength")), + details.get("maximum", details.get("maxLength")) + ) + + def _validate(self, path, skip_exception=False): + self._validate_required(path) for key, value in self._properties.items(): - self._validate_types(key, value) + if isinstance(value, OpenApiObject): + value._validate(path + ".%s" % key, True) + elif isinstance(value, OpenApiIter): + for ind, item in enumerate(value): + if not isinstance(item, OpenApiObject): + continue + item._validate(path + ".%s[%d]" % (key, ind), True) + self._validate_types(path + ".%s" % (key), key, value) + if skip_exception: + return self._validation_errors + self._raise_validation() + + def validate(self): + return self._validate(self._JSON_NAME) def get(self, name, with_default=False): """ @@ -641,14 +694,18 @@ def append(self, item): """Append an item to the end of OpenApiIter TBD: type check, raise error on mismatch """ - if isinstance(item, OpenApiObject) is False: - raise Exception("Item is not an instance of OpenApiObject") + self._instanceOf(item) self._add(item) return self def clear(self): del self._items[:] self._index = -1 + + def set(self, index, item): + self._instanceOf(item) + self._items[index] = item + return self def _encode(self): return [item._encode() for item in self._items] @@ -672,3 +729,6 @@ def __str__(self): def __eq__(self, other): return self.__str__() == other.__str__() + + def _instanceOf(self, item): + raise NotImplementedError("validating an OpenApiIter object is not supported") diff --git a/openapiart/generator.py b/openapiart/generator.py index 33ea52d9..9243ab23 100644 --- a/openapiart/generator.py +++ b/openapiart/generator.py @@ -705,10 +705,19 @@ def _get_external_field_name(self, openapi_name): external += "_" return external + def _small_first_char(self, words): + if isinstance(words, list): + ret = [] + for wrd in words: + ret.append(wrd[0].lower() + wrd[1:]) + return ret + return words[0].lower() + words[1:] + def _write_openapi_object(self, ref, choice_method_name=None): schema_object = self._get_object_from_ref(ref) ref_name = ref.split("/")[-1] class_name = ref_name.replace(".", "") + json_name = "_".join(self._small_first_char(ref_name.split("."))) if class_name in self._generated_classes: return self._generated_classes.append(class_name) @@ -723,7 +732,7 @@ def _write_openapi_object(self, ref, choice_method_name=None): if "choice" in self._get_choice_names(schema_object): slots.append("'_choice'") self._write(1, "__slots__ = (%s)" % ",".join(slots)) - self._write() + self._write(1, '_JSON_NAME = "%s"' % json_name) # write _TYPES definition # TODO: this func won't detect whether $ref for a given property is @@ -1069,6 +1078,14 @@ def _write_openapilist_special_methods( self._write(1, "def next(self):") self._write(2, "# type: () -> %s" % contained_class_name) self._write(2, "return self._next()") + self._write() + self._write(1, "def _instanceOf(self, item):") + self._write(2, "if not isinstance(item, %s):" % (contained_class_name)) + self._write( + 3, + 'raise Exception("Item is not an instance of %s")' + % (contained_class_name), + ) def _write_factory_method( self, diff --git a/openapiart/openapiartgo.py b/openapiart/openapiartgo.py index bbb90c24..772691ac 100644 --- a/openapiart/openapiartgo.py +++ b/openapiart/openapiartgo.py @@ -65,6 +65,7 @@ def __init__(self): self.description = None self.method_description = None self.interface_fields = [] + self.schema_raw_name = None def isOptional(self, property_name): if self.schema_object is None: @@ -79,6 +80,7 @@ def isOptional(self, property_name): class FluentField(object): def __init__(self): self.name = None + self.schema_name = None self.description = None self.getter_method_description = None self.setter_method_description = None @@ -265,6 +267,7 @@ def _write_common_code(self): self._write('import "github.com/ghodss/yaml"') self._write('import "google.golang.org/protobuf/encoding/protojson"') self._write('import "google.golang.org/protobuf/proto"') + self._write('import "google.golang.org/grpc/credentials/insecure"') go_pkg_fp = self._fp go_pkg_filename = self._filename self._filename = os.path.normpath( @@ -428,6 +431,9 @@ def _build_api_interface(self): new.schema_name = self._get_schema_object_name_from_ref( ref[0].value ) + new.schema_raw_name = self._get_schema_json_name_from_ref( + ref[0].value + ) new.schema_object = self._get_schema_object_from_ref( ref[0].value ) @@ -593,7 +599,8 @@ def _build_api_interface(self): if api.grpcClient == nil {{ ctx, cancelFunc := context.WithTimeout(context.Background(), api.grpc.dialTimeout) defer cancelFunc() - conn, err := grpc.DialContext(ctx, api.grpc.location, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.DialContext( + ctx, api.grpc.location, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil {{ return err }} @@ -880,6 +887,7 @@ def _build_response_interfaces(self): self._get_description(new.schema_object, True).lstrip("// "), ) new.schema_name = self._get_external_struct_name(new.interface) + new.schema_raw_name = rpc.operation_name # new.isRpcResponse = True self._api.external_new_methods.append(new) @@ -956,7 +964,7 @@ def _write_interface(self, new): return retObj }} {nil_call} - vErr := obj.validateFromText() + vErr := obj.validateFromText("FromPbText -> {obj_name}") if vErr != nil {{ return vErr }} @@ -998,7 +1006,7 @@ def _write_interface(self, new): uError.Error(), "\\u00a0", " ", -1)[7:]) }} {nil_call} - vErr := obj.validateFromText() + vErr := obj.validateFromText("FromYaml -> {obj_name}") if vErr != nil {{ return vErr }} @@ -1035,20 +1043,20 @@ def _write_interface(self, new): uError.Error(), "\\u00a0", " ", -1)[7:]) }} {nil_call} - err := obj.validateFromText() + err := obj.validateFromText("FromJson -> {obj_name}") if err != nil {{ return err }} return nil }} - func (obj *{struct}) validateFromText() error {{ - obj.validateObj(true) + func (obj *{struct}) validateFromText(path string) error {{ + obj.validateObj(true, path) return validationResult() }} func (obj *{struct}) Validate() error {{ - obj.validateObj(false) + obj.validateObj(false, "{obj_name}") return validationResult() }} @@ -1059,6 +1067,19 @@ def _write_interface(self, new): }} return str }} + + func (obj *{struct}) Clone() ({interface}, error) {{ + newObj := New{interface}() + pbText, err := obj.ToPbText() + if err != nil {{ + return nil, err + }} + pbErr := newObj.FromPbText(pbText) + if pbErr != nil {{ + return nil, pbErr + }} + return newObj, nil + }} """.format( struct=new.struct, pb_pkg_name=self._protobuf_package_name, @@ -1067,6 +1088,7 @@ def _write_interface(self, new): if len(internal_items) == 0 else "\n".join(internal_items), nil_call="obj.setNil()" if len(internal_items_nil) > 0 else "", + obj_name=new.schema_raw_name, ) ) if len(internal_items_nil) > 0: @@ -1097,8 +1119,10 @@ def _write_interface(self, new): "Validate() error", "// A stringer function", "String() string", - "validateFromText() error", - "validateObj(set_default bool)", + "// Clones the object", + "Clone() ({interface}, error)", + "validateFromText(path string) error", + "validateObj(set_default bool, path string)", "setDefault()", ] for field in new.interface_fields: @@ -1699,6 +1723,7 @@ def _build_setters_getters(self, fluent_new): field.schema = property_schema field.description = self._get_description(property_schema) field.name = self._get_external_field_name(property_name) + field.schema_name = property_name field.type = self._get_struct_field_type(property_schema, field) if ( len(choice_enums) == 1 @@ -1762,6 +1787,9 @@ def _build_setters_getters(self, fluent_new): schema_name = self._get_schema_object_name_from_ref( property_schema["$ref"] ) + field.schema_name = self._get_schema_json_name_from_ref( + property_schema["$ref"] + ) field.name = self._get_external_struct_name(schema_name) field.isOptional = fluent_new.isOptional(property_name) field.isPointer = ( @@ -1963,10 +1991,11 @@ def _validate_types(self, new, field): body = """ // {name} is required if obj.obj.{name}{enum} == {value} {{ - validation = append(validation, "{name} is required field on interface {interface}") + validation = append(validation, + fmt.Sprintf("required field `%s.{field_name}` must not be empty", path)) }} """.format( name=field.name, - interface=new.interface, + field_name=field.schema_name, value=0 if field.isEnum and field.isArray is False else value, enum=".Number()" if field.isEnum and field.isArray is False @@ -1990,12 +2019,13 @@ def _validate_types(self, new, field): + """ {{ validation = append( validation, - fmt.Sprintf("{min} <= {interface}.{name} <= {max} but Got {form}", {pointer}{value})) + fmt.Sprintf( + "length of field `%s.{name}` must be in range [{min}, {max}], instead of `{form}`", path, {pointer}{value})) }} """ ).format( - name=field.name, - interface=new.interface, + name=field.schema_name, + interface=new.schema_raw_name, max="max({})".format(field.type.lstrip("[]")) if field.max is None else field.max, @@ -2030,13 +2060,13 @@ def _validate_types(self, new, field): validation = append( validation, fmt.Sprintf( - "{min_length} <= length of {interface}.{name} <= {max_length} but Got %d", - len({pointer}{value}))) + "length of field `%s.{name}` must be in range [{min_length}, {max_length}], instead of `%d`", + path, len({pointer}{value}))) }} """ ).format( - name=field.name, - interface=new.interface, + name=field.schema_name, + interface=new.schema_name, max_length="any" if field.max_length is None else field.max_length, @@ -2065,13 +2095,13 @@ def _validate_types(self, new, field): if field.format is None: field.format = field.itemformat inner_body = """ - err := validate{format}(obj.{name}()) + err := validate{format}(obj.{name}(), fmt.Sprintf("%s.{field_name}", path)) if err != nil {{ - validation = append(validation, fmt.Sprintf("%s %s", err.Error(), "on {interface}.{name}")) + validation = append(validation, err.Error()) }} """.format( name=field.name, - interface=new.interface, + field_name=field.schema_name, format=field.format.capitalize() if field.isArray is False else field.format.capitalize() + "Slice", @@ -2094,14 +2124,19 @@ def _validate_struct(self, new, field): body = """ // {name} is required if obj.obj.{name} == nil {{ - validation = append(validation, "{name} is required field on interface {interface}") + validation = append( + validation, + fmt.Sprintf("required field `%s.{field_name}` must not be empty", path)) }} """.format( - name=field.name, interface=new.interface + name=field.name, + field_name=field.schema_name, ) - inner_body = "obj.{external_name}().validateObj(set_default)".format( - external_name=self._get_external_struct_name(field.name) + inner_body = """obj.{external_name}().validateObj( + set_default, fmt.Sprintf("%s.{json_name}", path))""".format( + external_name=self._get_external_struct_name(field.name), + json_name=field.schema_name, ) if field.isArray: inner_body = """ @@ -2110,12 +2145,13 @@ def _validate_struct(self, new, field): for _, item := range obj.obj.{name} {{ obj.{name}().appendHolderSlice(&{field_internal_struct}{{obj: item}}) }} - }} - for _, item := range obj.{name}().Items() {{ - item.validateObj(set_default) + }} + for ind, item := range obj.{name}().Items() {{ + item.validateObj(set_default, fmt.Sprintf("%s.{field_name}[%d]", path, ind)) }} """.format( name=field.name, + field_name=field.schema_name, field_internal_struct=field.struct, ) body += """ @@ -2163,7 +2199,7 @@ def p(): p() body = "\n".join(statements) self._write( - """func (obj *{struct}) validateObj(set_default bool) {{ + """func (obj *{struct}) validateObj(set_default bool, path string) {{ if set_default {{ obj.setDefault() }} @@ -2361,6 +2397,15 @@ def _write_default_method(self, new): ) ) + def _get_schema_json_name_from_ref(self, ref): + final_piece = ref.split("/")[-1] + if "." in final_piece: + return final_piece.replace(".", "_").lower() + return self._lower_first_char(final_piece) + + def _lower_first_char(self, word): + return word[0].lower() + word[1:] + def _get_schema_object_name_from_ref(self, ref): final_piece = ref.split("/")[-1] return final_piece.replace(".", "") @@ -2413,6 +2458,7 @@ def _get_struct_field_type(self, property_schema, fluent_field=None): new = FluentNew() new.schema_object = schema_object new.schema_name = schema_object_name + new.schema_raw_name = self._get_schema_json_name_from_ref(ref) new.struct = self._get_internal_name(schema_object_name) new.interface = self._get_external_struct_name( schema_object_name diff --git a/openapiart/tests/json_configs/config.json b/openapiart/tests/json_configs/config.json index aaefa52e..77800e65 100644 --- a/openapiart/tests/json_configs/config.json +++ b/openapiart/tests/json_configs/config.json @@ -1,4 +1,8 @@ { + "required_object": { + "e_a": 10.1, + "e_b": 2245.1111 + }, "a": "asdf", "b": 1.1, "c": 1, diff --git a/openapiart/tests/test_add.py b/openapiart/tests/test_add.py index eff880ee..f1a58e76 100644 --- a/openapiart/tests/test_add.py +++ b/openapiart/tests/test_add.py @@ -12,8 +12,8 @@ def test_add(api): assert config.f.f_b == config.f._DEFAULTS["f_b"] g1 = config.g.add(name="unique list name", g_a="dkdkd", g_b=3, g_c=22.2) g1.g_d = "gdgdgd" - j = config.j.add() - j.j_b.f_a = "a" + jval = config.j.add() + jval.j_b.f_a = "a" print(config) assert config.g[0].choice == "g_d" yaml = config.serialize(encoding=config.YAML) diff --git a/openapiart/tests/test_formats.py b/openapiart/tests/test_formats.py index 80e36de7..b9f80f4f 100644 --- a/openapiart/tests/test_formats.py +++ b/openapiart/tests/test_formats.py @@ -20,10 +20,8 @@ def test_formats_bad_string(config, value): config.l.string_param = value try: config.deserialize(config.serialize(encoding=config.YAML)) - pytest.fail( - "Value {value} was successfully validated".format(value=value) - ) - except TypeError: + pytest.fail("Value {} was successfully validated".format(value)) + except Exception: pass @@ -33,7 +31,7 @@ def test_formats_bad_integer(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -43,7 +41,7 @@ def test_formats_integer_to_be_removed(config, value): config.l.integer = value config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -52,7 +50,7 @@ def test_formats_good_ipv4(config, value): config.l.ipv4 = value try: config.deserialize(config.serialize(encoding=config.YAML)) - except TypeError: + except Exception: pytest.fail("Value {} was not valid".format(value)) @@ -75,7 +73,7 @@ def test_formats_bad_ipv4(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -85,7 +83,7 @@ def test_formats_ipv4_to_be_removed(config, value): config.l.ipv4 = value config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -108,7 +106,7 @@ def test_formats_bad_ipv6(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -130,7 +128,7 @@ def test_formats_bad_mac(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -142,7 +140,7 @@ def test_formats_bad_hex(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass @@ -152,7 +150,7 @@ def test_string_length(config, value): try: config.deserialize(config.serialize(encoding=config.YAML)) pytest.fail("Value {} was successfully validated".format(value)) - except TypeError: + except Exception: pass diff --git a/openapiart/tests/test_func.py b/openapiart/tests/test_func.py index d2a34bbe..a9097b40 100644 --- a/openapiart/tests/test_func.py +++ b/openapiart/tests/test_func.py @@ -20,7 +20,7 @@ def test_required(api): config.mandatory config.mandatory.serialize() pytest.fail("config got validated") - except ValueError: + except Exception: pass @@ -103,8 +103,8 @@ def test_x_pattern_ipv4_good_and_bad_list(default_config, ipv4): try: default_config.ipv4_pattern.serialize(default_config.DICT) pytest.fail("ipv4 values got serialize") - except TypeError as e: - if "['-255.-255.-255.-255']" not in str(e): + except Exception as e: + if "`-255.-255.-255.-255`" not in str(e): pytest.fail("Invalid ipv4 list is not proper in error message") @@ -114,8 +114,8 @@ def test_x_pattern_ipv6_good_and_bad_list(default_config, ipv6): try: default_config.ipv6_pattern.serialize(default_config.DICT) pytest.fail("ipv6 values got serialize") - except TypeError as e: - if "[':', 'abcd::abcd::']" not in str(e): + except Exception as e: + if "`abcd::abcd::`" not in str(e) or "`:`" not in str(e): pytest.fail("Invalid ipv6 list is not proper in error message") @@ -125,8 +125,8 @@ def test_x_pattern_mac_good_and_bad_list(default_config, mac): try: default_config.mac_pattern.serialize(default_config.DICT) pytest.fail("mac values got serialize") - except TypeError as e: - if "[':', 'abcd::abcd::']" not in str(e): + except Exception as e: + if "`abcd::abcd::`" not in str(e) or "`:`" not in str(e): pytest.fail("Invalid mac list is not proper in error message") @@ -138,8 +138,12 @@ def test_x_pattern_integer_good_and_bad_list(default_config, integer): try: default_config.integer_pattern.serialize(default_config.DICT) pytest.fail("integer values got serialize") - except TypeError as e: - if "['abcd::abcd::', 256, 'ab:ab:ab:ab:ab:ab']" not in str(e): + except Exception as e: + if ( + "`abcd::abcd::`" not in str(e) + or "`256`" not in str(e) + or "`ab:ab:ab:ab:ab:ab`" not in str(e) + ): pytest.fail("Invalid integer list is not proper in error message") @@ -159,7 +163,7 @@ def test_x_pattern_good_inc_dec(default_config, index, direction): dir_obj.count = count[index] try: default_config.serialize(default_config.DICT) - except TypeError: + except Exception: pytest.fail("%s with %s Failed to serialize" % (enum, direction)) @@ -180,7 +184,7 @@ def test_x_pattern_bad_inc_dec(default_config, index, direction): try: default_config.serialize(default_config.DICT) pytest.fail("%s with %s got serialized" % (enum, direction)) - except TypeError as e: + except Exception as e: print(e) @@ -196,7 +200,15 @@ def test_int_64_format(api, default_config): conf.integer64 = "2000" try: conf.validate() - except TypeError as e: + except Exception as e: + print(e) + + +def test_enum_setter(api, default_config): + default_config.response = "abc" + try: + default_config.validate() + except Exception as e: print(e) diff --git a/openapiart/tests/test_py_go_diff.py b/openapiart/tests/test_py_go_diff.py new file mode 100644 index 00000000..eaaa7501 --- /dev/null +++ b/openapiart/tests/test_py_go_diff.py @@ -0,0 +1,97 @@ +import importlib +import pytest + +module = importlib.import_module("sanity") + + +def test_iter_set_method(default_config): + default_config.j.add() + default_config.j.set(0, module.JObject()) + assert len(default_config.j) == 1 + try: + default_config.j.append(module.FObject()) + pytest.fail("appending an invalid object is not throwing exception") + except Exception: + pass + try: + default_config.j.set(0, module.FObject()) + pytest.fail("setting an invalid object is not throwing exception") + except Exception: + pass + + assert isinstance(default_config.j[0], module.EObject) + + +def test_validation_errors(): + p = module.Api().prefix_config() + p.e + try: + p.validate() + pytest.fail + except Exception as e: + assert "required field `prefix_config.a` must not be empty" in str(e) + assert "required field `prefix_config.b` must not be empty" in str(e) + assert "required field `prefix_config.c` must not be empty" in str(e) + assert ( + "required field `prefix_config.required_object` must not be empty" + in str(e) + ) + assert "required field `prefix_config.e.e_a` must not be empty" in str( + e + ) + assert "required field `prefix_config.e.e_b` must not be empty" in str( + e + ) + + p.e.e_a = "abc" + try: + p.validate() + except Exception as e: + print(e) + assert ( + "value of `prefix_config.e.e_a` must be a valid float type, instead of `abc`" + in str(e) + ) + p.a = "abc" + p.b = 10.1 + p.c = 20 + p.required_object.e_a = 10.1 + p.required_object.e_b = 20 + p.j.add().j_a + p.mac_pattern.mac.values = ["1", "20"] + p.ipv4_pattern.ipv4.value = "1.1" + errors = p._validate(p._JSON_NAME, True) + assert len([True for e in errors if ".e_b` must not be empty" in e]) == 2 + assert ( + "required field `prefix_config.j[0].e_a` must not be empty" in errors + ) + assert "required field `prefix_config.e.e_b` must not be empty" in errors + assert ( + "value of `prefix_config.e.e_a` must be a valid float type, instead of `abc`" + in errors + ) + assert ( + "required field `prefix_config.j[0].e_a` must not be empty" in errors + ) + assert ( + "required field `prefix_config.j[0].e_b` must not be empty" in errors + ) + assert ( + "value of `prefix_config.mac_pattern.mac.values[0]` must be a valid mac address, instead of `1`" + in errors + ) + assert ( + "value of `prefix_config.mac_pattern.mac.values[1]` must be a valid mac address, instead of `20`" + in errors + ) + assert ( + "value of `prefix_config.ipv4_pattern.ipv4.value` must be a valid ipv4 address, instead of `1.1`" + in errors + ) + + +def test_enum_setter(): + p = module.Api().prefix_config() + p.response = "abc" + errors = p._validate(p._JSON_NAME, True) + assert "abc is not a valid enum for property response" in errors diff --git a/pkg/common.go b/pkg/common.go index c7eadf1a..df86b0ff 100644 --- a/pkg/common.go +++ b/pkg/common.go @@ -164,48 +164,46 @@ func validationResult() error { return nil } -func validateMac(mac string) error { +func validateMac(mac string, path string) error { macSlice := strings.Split(mac, ":") if len(macSlice) != 6 { - return fmt.Errorf(fmt.Sprintf("Invalid Mac address %s", mac)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid mac string, instead of `%s`", path, mac)) } - octInd := []string{"0th", "1st", "2nd", "3rd", "4th", "5th"} - for ind, val := range macSlice { + for _, val := range macSlice { num, err := strconv.ParseUint(val, 16, 32) if err != nil || num > 255 { - return fmt.Errorf(fmt.Sprintf("Invalid Mac address at %s octet in %s mac", octInd[ind], mac)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid mac string, instead of `%s`", path, mac)) } } return nil } -func validateIpv4(ip string) error { +func validateIpv4(ip string, path string) error { ipSlice := strings.Split(ip, ".") if len(ipSlice) != 4 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv4 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv4 string, instead of `%s`", path, ip)) } - octInd := []string{"1st", "2nd", "3rd", "4th"} - for ind, val := range ipSlice { + for _, val := range ipSlice { num, err := strconv.ParseUint(val, 10, 32) if err != nil || num > 255 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv4 address at %s octet in %s ipv4", octInd[ind], ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv4 string, instead of `%s`", path, ip)) } } return nil } -func validateIpv6(ip string) error { +func validateIpv6(ip string, path string) error { ip = strings.Trim(ip, " \t") if strings.Count(ip, " ") > 0 || strings.Count(ip, ":") > 7 || strings.Count(ip, "::") > 1 || strings.Count(ip, ":::") > 0 || strings.Count(ip, ":") == 0 { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if (string(ip[0]) == ":" && string(ip[:2]) != "::") || (string(ip[len(ip)-1]) == ":" && string(ip[len(ip)-2:]) != "::") { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if strings.Count(ip, "::") == 0 && strings.Count(ip, ":") != 7 { - return fmt.Errorf(fmt.Sprintf("Invalid ipv6 address %s", ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } if ip == "::" { return nil @@ -220,69 +218,69 @@ func validateIpv6(ip string) error { r := strings.NewReplacer("::", ":0:") ip = r.Replace(ip) } - octInd := []string{"1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"} ipSlice := strings.Split(ip, ":") - for ind, val := range ipSlice { + for _, val := range ipSlice { num, err := strconv.ParseUint(val, 16, 64) if err != nil || num > 65535 { - return fmt.Errorf(fmt.Sprintf("Invalid Ipv6 address at %s octet in %s ipv6", octInd[ind], ip)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid ipv6 string, instead of `%s`", path, ip)) } } return nil } -func validateHex(hex string) error { +func validateHex(hex string, path string) error { matched, err := regexp.MatchString(`^[0-9a-fA-F]+$|^0[x|X][0-9a-fA-F]+$`, hex) if err != nil || !matched { - return fmt.Errorf(fmt.Sprintf("Invalid hex value %s", hex)) + return fmt.Errorf(fmt.Sprintf("value of `%s` must be a valid hex string, instead of %s", path, hex)) } return nil } -func validateSlice(valSlice []string, sliceType string) error { +func validateSlice(valSlice []string, sliceType string, path string) error { indices := []string{} var err error for i, val := range valSlice { if sliceType == "mac" { - err = validateMac(val) + err = validateMac(val, path) } else if sliceType == "ipv4" { - err = validateIpv4(val) + err = validateIpv4(val, path) } else if sliceType == "ipv6" { - err = validateIpv6(val) + err = validateIpv6(val, path) } else if sliceType == "hex" { - err = validateHex(val) + err = validateHex(val, path) } else { - return fmt.Errorf(fmt.Sprintf("Invalid slice type received <%s>", sliceType)) + return fmt.Errorf(fmt.Sprintf("invalid slice type received <%s>", sliceType)) } if err != nil { - indices = append(indices, fmt.Sprintf("%d", i)) + indices = append(indices, + fmt.Sprintf("value of `%s[%d]` must be a valid %s string, instead of `%s`", path, i, sliceType, val)) } } if len(indices) > 0 { return fmt.Errorf( - fmt.Sprintf("Invalid %s addresses at indices %s", sliceType, strings.Join(indices, ",")), + strings.Join(indices, "\n"), ) } return nil } -func validateMacSlice(mac []string) error { - return validateSlice(mac, "mac") +func validateMacSlice(mac []string, path string) error { + return validateSlice(mac, "mac", path) } -func validateIpv4Slice(ip []string) error { - return validateSlice(ip, "ipv4") +func validateIpv4Slice(ip []string, path string) error { + return validateSlice(ip, "ipv4", path) } -func validateIpv6Slice(ip []string) error { - return validateSlice(ip, "ipv6") +func validateIpv6Slice(ip []string, path string) error { + return validateSlice(ip, "ipv6", path) } -func validateHexSlice(hex []string) error { - return validateSlice(hex, "hex") +func validateHexSlice(hex []string, path string) error { + return validateSlice(hex, "hex", path) } diff --git a/pkg/generated_required_test.go b/pkg/generated_required_test.go index d891c9d2..3b065e75 100644 --- a/pkg/generated_required_test.go +++ b/pkg/generated_required_test.go @@ -20,11 +20,16 @@ func TestPrefixConfigRequired(t *testing.T) { data, _ := opts.Marshal(object.Msg()) err := object.FromJson(string(data)) err1 := object.FromYaml(string(data)) - protoMarshal, _ := proto.Marshal(object.Msg()) - err2 := object.FromPbText(string(protoMarshal)) - assert.Contains(t, err.Error(), "RequiredObject", "A", "B", "C") - assert.Contains(t, err1.Error(), "RequiredObject", "A", "B", "C") - assert.Contains(t, err2.Error(), "RequiredObject", "A", "B", "C") + str, _ := proto.Marshal(object.Msg()) + err2 := object.FromPbText(string(str)) + assert.Contains(t, err.Error(), "prefix_config.required_object") + assert.Contains(t, err.Error(), "prefix_config.a") + assert.Contains(t, err1.Error(), "prefix_config.required_object") + assert.Contains(t, err1.Error(), "prefix_config.a") + assert.Contains(t, err2.Error(), "prefix_config.required_object") + assert.Contains(t, err2.Error(), "prefix_config.a") + // assert.Contains(t, err2.Error(), "PrefixConfig.b") + // assert.Contains(t, err2.Error(), "PrefixConfig.c") } // func TestEObjectRequired(t *testing.T) { @@ -54,11 +59,11 @@ func TestMandateRequired(t *testing.T) { data, _ := opts.Marshal(object.Msg()) err := object.FromJson(string(data)) err1 := object.FromYaml(string(data)) - protoMarshal, _ := proto.Marshal(object.Msg()) - err2 := object.FromPbText(string(protoMarshal)) - assert.Contains(t, err.Error(), "RequiredParam") - assert.Contains(t, err1.Error(), "RequiredParam") - assert.Contains(t, err2.Error(), "RequiredParam") + str, _ := proto.Marshal(object.Msg()) + err2 := object.FromPbText(string(str)) + assert.Contains(t, err.Error(), "mandate.required_param") + assert.Contains(t, err1.Error(), "mandate.required_param") + assert.Contains(t, err2.Error(), "mandate.required_param") } func TestMObjectRequired(t *testing.T) { object := openapiart.NewMObject() @@ -71,11 +76,32 @@ func TestMObjectRequired(t *testing.T) { data, _ := opts.Marshal(object.Msg()) err := object.FromJson(string(data)) err1 := object.FromYaml(string(data)) - protoMarshal, _ := proto.Marshal(object.Msg()) - err2 := object.FromPbText(string(protoMarshal)) - assert.Contains(t, err.Error(), "StringParam", "Integer", "Float", "Double", "Mac", "Ipv4", "Ipv6", "Hex") - assert.Contains(t, err1.Error(), "StringParam", "Integer", "Float", "Double", "Mac", "Ipv4", "Ipv6", "Hex") - assert.Contains(t, err2.Error(), "StringParam", "Integer", "Float", "Double", "Mac", "Ipv4", "Ipv6", "Hex") + str, _ := proto.Marshal(object.Msg()) + err2 := object.FromPbText(string(str)) + assert.Contains(t, err.Error(), "mObject.string_param") + // assert.Contains(t, err.Error(), "MObject.integer") + assert.Contains(t, err.Error(), "mObject.ipv4") + assert.Contains(t, err.Error(), "mObject.mac") + // assert.Contains(t, err.Error(), "MObject.float") + // assert.Contains(t, err.Error(), "MObject.double") + assert.Contains(t, err.Error(), "mObject.ipv6") + assert.Contains(t, err.Error(), "mObject.hex") + assert.Contains(t, err1.Error(), "mObject.string") + // assert.Contains(t, err1.Error(), "MObject.integer") + assert.Contains(t, err1.Error(), "mObject.ipv4") + assert.Contains(t, err1.Error(), "mObject.mac") + // assert.Contains(t, err1.Error(), "MObject.float") + // assert.Contains(t, err1.Error(), "MObject.double") + assert.Contains(t, err1.Error(), "mObject.ipv6") + assert.Contains(t, err1.Error(), "mObject.hex") + assert.Contains(t, err2.Error(), "mObject.string") + // assert.Contains(t, err2.Error(), "MObject.integer") + assert.Contains(t, err2.Error(), "mObject.ipv4") + assert.Contains(t, err2.Error(), "mObject.mac") + // assert.Contains(t, err2.Error(), "MObject.float") + // assert.Contains(t, err2.Error(), "MObject.double") + assert.Contains(t, err2.Error(), "mObject.ipv6") + assert.Contains(t, err2.Error(), "mObject.hex") } func TestPortMetricRequired(t *testing.T) { object := openapiart.NewPortMetric() @@ -88,9 +114,9 @@ func TestPortMetricRequired(t *testing.T) { data, _ := opts.Marshal(object.Msg()) err := object.FromJson(string(data)) err1 := object.FromYaml(string(data)) - protoMarshal, _ := proto.Marshal(object.Msg()) - err2 := object.FromPbText(string(protoMarshal)) - assert.Contains(t, err.Error(), "Name", "TxFrames", "RxFrames") - assert.Contains(t, err1.Error(), "Name", "TxFrames", "RxFrames") - assert.Contains(t, err2.Error(), "Name", "TxFrames", "RxFrames") + str, _ := proto.Marshal(object.Msg()) + err2 := object.FromPbText(string(str)) + assert.Contains(t, err.Error(), "port_metric.name") + assert.Contains(t, err1.Error(), "port_metric.name") + assert.Contains(t, err2.Error(), "port_metric.name") } diff --git a/pkg/unit_test.go b/pkg/unit_test.go index f27c1ea8..b9df67b1 100644 --- a/pkg/unit_test.go +++ b/pkg/unit_test.go @@ -381,7 +381,7 @@ func TestBadMacValidation(t *testing.T) { macObj := config.MacPattern().Mac().SetValue(mac) err := macObj.Validate() if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Invalid Mac") + assert.Contains(t, err.Error(), "value of `pattern_macpattern_mac.value` must be a valid mac") } } } @@ -401,7 +401,7 @@ func TestBadMacValues(t *testing.T) { err := mac.Validate() fmt.Println(err.Error()) if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid mac address") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid mac string") } } @@ -414,7 +414,7 @@ func TestBadMacIncrement(t *testing.T) { err := mac.Validate() fmt.Println(err.Error()) if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid mac address") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid mac string") } } @@ -427,7 +427,7 @@ func TestBadMacDecrement(t *testing.T) { err := mac.Validate() fmt.Println(err.Error()) if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid mac address") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid mac string") } } @@ -449,7 +449,7 @@ func TestBadIpv4Validation(t *testing.T) { ipv4 := config.Ipv4Pattern().Ipv4().SetValue(ip) err := ipv4.Validate() if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Invalid Ipv4") + assert.Contains(t, err.Error(), "must be a valid ipv4 string") } } } @@ -460,7 +460,7 @@ func TestBadIpv4Values(t *testing.T) { ipv4 := config.Ipv4Pattern().Ipv4().SetValues(BadIpv4) err := ipv4.Validate() if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Invalid ipv4 addresses") + assert.Contains(t, err.Error(), "must be a valid ipv4 string") } } @@ -472,7 +472,7 @@ func TestBadIpv4Increment(t *testing.T) { ipv4.SetCount(10) err := ipv4.Validate() if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Invalid Ipv4") + assert.Contains(t, err.Error(), "must be a valid ipv4 string") } } @@ -484,7 +484,7 @@ func TestBadIpv4Decrement(t *testing.T) { ipv4.SetCount(10) err := ipv4.Validate() if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Invalid Ipv4") + assert.Contains(t, err.Error(), "must be a valid ipv4 string") } } @@ -506,7 +506,7 @@ func TestBadIpv6Validation(t *testing.T) { ipv6 := config.Ipv6Pattern().Ipv6().SetValue(ip) err := ipv6.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid ipv6") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid ipv6 string") } } } @@ -517,7 +517,7 @@ func TestBadIpv6Values(t *testing.T) { ipv6 := config.Ipv6Pattern().Ipv6().SetValues(BadIpv6) err := ipv6.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid ipv6 address") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid ipv6 string") } } @@ -529,7 +529,7 @@ func TestBadIpv6Increment(t *testing.T) { ipv6.SetCount(10) err := ipv6.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid ipv6") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid ipv6 string") } } @@ -541,7 +541,7 @@ func TestBadIpv6Decrement(t *testing.T) { ipv6.SetCount(10) err := ipv6.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "invalid ipv6") + assert.Contains(t, strings.ToLower(err.Error()), "must be a valid ipv6 string") } } @@ -670,7 +670,7 @@ func TestRequiredField(t *testing.T) { mandate := openapiart.NewMandate() err := mandate.Validate() assert.NotNil(t, err) - assert.Contains(t, err.Error(), "RequiredParam is required field") + assert.Contains(t, err.Error(), "required field `mandate.required_param` must not be empty") } func TestOptionalDefault(t *testing.T) { @@ -757,7 +757,7 @@ func TestFromJsonToCleanObject(t *testing.T) { }` err1 := config.FromJson(new_json1) assert.NotNil(t, err1) - assert.Contains(t, err1.Error(), "A is required field") + assert.Contains(t, err1.Error(), "required field `FromJson -> prefix_config.a` must not be empty") } func TestChoiceStale(t *testing.T) { @@ -980,7 +980,8 @@ func TestStringLengthError(t *testing.T) { config.Name() err := config.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "3 <= length of prefixconfig.strlen <= 6 but got 8") + assert.Contains(t, err.Error(), + "length of field `prefix_config.str_len` must be in range [3, 6], instead of `8`") } } @@ -1015,7 +1016,7 @@ func TestMObjectValidation(t *testing.T) { mObject := openapiart.NewMObject() err := mObject.Validate() if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), "required field on interface mobject") + assert.Contains(t, err.Error(), "required field") } } @@ -1037,13 +1038,14 @@ func TestMobjectValidationError(t *testing.T) { SetIpv4("1.1.1.1.2") config.SetResponse(openapiart.PrefixConfigResponse.STATUS_400) err := config.Validate() + fmt.Println(err.Error()) assert.NotNil(t, err) if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), - "invalid mac address", - "invalid ipv4 address", - "invalid hex value", - "invalid ipv6 address") + assert.Contains(t, err.Error(), "must be in range [10, 90], instead of `120`") + assert.Contains(t, err.Error(), "must be a valid hex") + assert.Contains(t, err.Error(), "must be a valid mac") + assert.Contains(t, err.Error(), "must be a valid ipv4") + assert.Contains(t, err.Error(), "must be a valid ipv6") } } @@ -1062,11 +1064,10 @@ func TestLObjectError(t *testing.T) { err := config.Validate() assert.NotNil(t, err) if assert.Error(t, err) { - assert.Contains(t, strings.ToLower(err.Error()), - "invalid mac address", - "invalid ipv4 address", - "invalid hex value", - "invalid ipv6 address") + assert.Contains(t, err.Error(), "must be a valid hex") + assert.Contains(t, err.Error(), "must be a valid mac") + assert.Contains(t, err.Error(), "must be a valid ipv4") + assert.Contains(t, err.Error(), "must be a valid ipv6") } }