diff options
Diffstat (limited to 'gnu/packages/patches/python-pytorch-fix-codegen.patch')
-rw-r--r-- | gnu/packages/patches/python-pytorch-fix-codegen.patch | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/gnu/packages/patches/python-pytorch-fix-codegen.patch b/gnu/packages/patches/python-pytorch-fix-codegen.patch index 106ea7db66..3862339b14 100644 --- a/gnu/packages/patches/python-pytorch-fix-codegen.patch +++ b/gnu/packages/patches/python-pytorch-fix-codegen.patch @@ -6,7 +6,7 @@ is later corrected. codegen_external.py is patched to avoid duplicate functions and add the static keyword as in the existing generated file. diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh -index cc0263dbbf..ac34e84b82 100644 +index cc0263dbb..ac34e84b8 100644 --- a/tools/gen_flatbuffers.sh +++ b/tools/gen_flatbuffers.sh @@ -1,13 +1,13 @@ @@ -32,7 +32,7 @@ index cc0263dbbf..ac34e84b82 100644 -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs" echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h" diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py -index 5dcf1b2840..0e20b0c102 100644 +index 5dcf1b284..0e20b0c10 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py @@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path): @@ -61,7 +61,7 @@ index 5dcf1b2840..0e20b0c102 100644 void** buf_data, int64_t* buf_ranks, diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py -index b42948045c..e1cfc73a5e 100644 +index b42948045..e1cfc73a5 100644 --- a/torchgen/decompositions/gen_jit_decompositions.py +++ b/torchgen/decompositions/gen_jit_decompositions.py @@ -1,8 +1,12 @@ @@ -88,7 +88,7 @@ index b42948045c..e1cfc73a5e 100644 write_decomposition_util_file(str(upgrader_path)) diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py -index 362ce427d5..245056f815 100644 +index 845034cb7..a1c5767c2 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -6,10 +6,13 @@ import os @@ -107,9 +107,9 @@ index 362ce427d5..245056f815 100644 from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( MOBILE_UPGRADERS_HEADER_DESCRIPTION, -@@ -265,7 +268,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: +@@ -263,7 +266,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( - upgrader_bytecode_function_to_index_map: dict[str, Any] + upgrader_bytecode_function_to_index_map: dict[str, Any], ) -> str: - version_map = torch._C._get_operator_version_map() + if len(sys.argv) < 2 or sys.argv[1] != "dummy": @@ -119,7 +119,7 @@ index 362ce427d5..245056f815 100644 sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] sorted_version_map = dict(sorted_version_map_) -@@ -381,7 +387,10 @@ def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: +@@ -375,7 +381,10 @@ def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: def main() -> None: @@ -132,7 +132,7 @@ index 362ce427d5..245056f815 100644 for up in sorted_upgrader_list: print("after sort upgrader : ", next(iter(up))) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py -index 56a3d8bf0d..490a3ea2e7 100644 +index 56a3d8bf0..ffd0785fd 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,6 +1,7 @@ @@ -143,7 +143,7 @@ index 56a3d8bf0d..490a3ea2e7 100644 from importlib.util import module_from_spec, spec_from_file_location from itertools import chain from pathlib import Path -@@ -18,16 +19,21 @@ you are in the root directory of the Pytorch git repo""" +@@ -18,17 +19,21 @@ you are in the root directory of the Pytorch git repo""" if not file_path.exists(): raise Exception(err_msg) # noqa: TRY002 @@ -157,6 +157,7 @@ index 56a3d8bf0d..490a3ea2e7 100644 - -bounded_compute_graph_mapping = module.bounded_compute_graph_mapping -shape_compute_graph_mapping = module.shape_compute_graph_mapping +- +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + spec = importlib.util.spec_from_file_location(module_name, file_path) + assert spec is not None @@ -173,5 +174,5 @@ index 56a3d8bf0d..490a3ea2e7 100644 + bounded_compute_graph_mapping = {} + shape_compute_graph_mapping = {} - SHAPE_HEADER = r""" + /** |